diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..128d516a0c91a6715b8bf8d03437c413e89b5d85
--- /dev/null
+++ b/app.py
@@ -0,0 +1,373 @@
+import os
+os.system("cd open_flamingo && pip install .")
+import numpy as np
+import torch
+from PIL import Image
+
+
+import string
+import cv2
+
+
+import gradio as gr
+import torch
+from PIL import Image
+from huggingface_hub import hf_hub_download, login
+
+from open_flamingo.src.factory import create_model_and_transforms
+flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
+        "ViT-L-14",
+        "datacomp_xl_s13b_b90k",
+        "EleutherAI/pythia-1.4b",
+        "EleutherAI/pythia-1.4b",
+        add_visual_grounding=True,
+        location_token_num=1000,
+        add_visual_token = True,
+        use_format_v2 = True,
+    )
+
+checkpoint_path = hf_hub_download("chendl/compositional_test", "pythiaS.pt")
+checkpoint = torch.load(checkpoint_path, map_location="cpu")
+model_state_dict = {}
+for key in checkpoint.keys():
+    model_state_dict[key.replace("module.", "")] = checkpoint[key]
+if "vision_encoder.logit_scale"in model_state_dict:
+    # previous checkpoint has some unnecessary weights
+    del model_state_dict["vision_encoder.logit_scale"]
+    del model_state_dict["vision_encoder.visual.proj"]
+    del model_state_dict["vision_encoder.visual.ln_post.weight"]
+    del model_state_dict["vision_encoder.visual.ln_post.bias"]
+flamingo.load_state_dict(model_state_dict, strict=True)
+
+def get_outputs(
+    model,
+    batch_images,
+    attention_mask,
+    max_generation_length,
+    min_generation_length,
+    num_beams,
+    length_penalty,
+    input_ids,
+    image_start_index_list=None,
+    image_nums=None,
+    bad_words_ids=None,
+):
+    #  and torch.cuda.amp.autocast(dtype=torch.float16)
+    with torch.inference_mode():
+        outputs = model.generate(
+            batch_images,
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_generation_length,
+            min_length=min_generation_length,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+            bad_words_ids=bad_words_ids,
+        )
+
+    return outputs
+
+
+def evaluate_refcoco(
+        model,
+        tokenizer,
+        image_processor,
+        batch_size,
+        tsvfile,
+        max_generation_length=20,
+        num_beams=3,
+        length_penalty=-2.0,
+        device=-1,
+        vis_embed_size=None,
+        rank=0,
+        world_size=1,
+        id=0,
+):
+    model.eval().cuda()
+    loc_token_ids = []
+    for i in range(1000):
+        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    # all_ids = set(range(model.lang_encoder.lm_head.out_features))
+    # bad_words_ids = list(all_ids - set(loc_token_ids))
+    # bad_words_ids = [[b] for b in bad_words_ids]
+    # min_loc_token_id = min(loc_token_ids)
+    # max_loc_token_id = max(loc_token_ids)
+    total = 0
+    correct = 0
+    ious = []
+    if "refcocog" in tsvfile:
+        dataset_name = "refcocog"
+    elif "refcocoplus" in tsvfile:
+        dataset_name = "refcocoplus"
+    else:
+        dataset_name = "refcoco"
+    with open(tsvfile, "r") as f:
+        lines = f.readlines()
+        pbar = tqdm(lines, disable=(rank != 0))
+        for ii, line in enumerate(pbar):
+            if ii % world_size != rank:
+                continue
+            total += 1
+            line = line.rstrip()
+            uniq_id, image_id, text, region_coord, image = line.split("\t")
+
+            image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
+
+            gt_box = np.array(list(map(float, region_coord.split(","))))
+            width = image.width
+            height = image.height
+            image = image.resize((224, 224))
+            gt_box = gt_box / np.array([width, height, width, height]) * 224
+            batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+            prompt = [
+                f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
+
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            # attention_mask[input_ids == prebox_token_id] = 0
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+
+            model.debug_id = 0
+            with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=None,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=None,
+                    add_box=False,
+                )
+            boxes = outputs["boxes"]
+            scores = outputs["scores"]
+            if len(scores) > 0:
+                box = boxes[scores.argmax()]
+                iou = get_iou(box, gt_box)
+            else:
+                iou = 0.0
+                # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
+                tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}")
+            if iou >= 0.5:
+                correct += 1
+            pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
+            # open_cv_image = np.array(image)
+            # # Convert RGB to BGR
+            # open_cv_image = open_cv_image[:, :, ::-1].copy()
+            # for box, score in zip(boxes, scores):
+            #     open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+            # cv2.imwrite("output.jpg", open_cv_image)
+            # print(boxes)
+            # print(scores)
+            # exit()
+
+
+def generate(
+    idx,
+    image,
+    text,
+    vis_embed_size=256,
+    rank=0,
+    world_size=1,
+):
+    if image is None:
+        raise gr.Error("Please upload an image.")
+    flamingo.eval()
+    loc_token_ids = []
+    for i in range(1000):
+        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+
+    image_ori = image
+    image = image.convert("RGB")
+    width = image.width
+    height = image.height
+    image = image.resize((224, 224))
+    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+    if idx == 1:
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
+        bad_words_ids = None
+        max_generation_length = 5
+    else:
+        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
+        bad_words_ids = loc_word_ids
+        max_generation_length = 30
+    encodings = tokenizer(
+        prompt,
+        padding="longest",
+        truncation=True,
+        return_tensors="pt",
+        max_length=2000,
+    )
+    input_ids = encodings["input_ids"]
+    attention_mask = encodings["attention_mask"]
+    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+    image_start_index_list = [[x] for x in image_start_index_list]
+    image_nums = [1] * len(input_ids)
+    outputs = get_outputs(
+        model=flamingo,
+        batch_images=batch_images,
+        attention_mask=attention_mask,
+        max_generation_length=max_generation_length,
+        min_generation_length=4,
+        num_beams=1,
+        length_penalty=1.0,
+        input_ids=input_ids,
+        bad_words_ids=bad_words_ids,
+        image_start_index_list=image_start_index_list,
+        image_nums=image_nums,
+    )
+    boxes = outputs["boxes"]
+    scores = outputs["scores"]
+    if len(scores) > 0:
+        box = boxes[scores.argmax()]
+        iou = get_iou(box, gt_box)
+    else:
+        iou = 0.0
+        # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
+        tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}")
+    if iou >= 0.5:
+        correct += 1
+
+
+    gen_text = tokenizer.batch_decode(outputs)
+    if idx == 1:
+        return f"Output:{gen_text}", out_image
+    elif idx == 2:
+        return (f"Question: {text.strip()} Answer: {gen_text}")
+    else:
+        return (f"Output:{gen_text}")
+
+
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+    🍜 Object Centric Pretraining Demo  
+    In this demo we showcase the in-context learning and grounding capabilities of the Object-Centric Pretrained model, a large multimodal model. Note that we add two additional demonstrations to the ones presented to improve the demo experience.
+    The model is trained on an interleaved mixture of text, images and bounding box and is able to generate text conditioned on sequences of images/text.
+    """
+    )
+
+    with gr.Accordion("See terms and conditions"):
+        gr.Markdown(
+            """**Please read the following information carefully before proceeding.**This demo does NOT store any personal information on its users, and it does NOT store user queries.""")
+
+    with gr.Tab("📷 Image Captioning"):
+        with gr.Row():
+
+
+            query_image = gr.Image(type="pil")
+        with gr.Row():
+            chat_input = gr.Textbox(lines=1, label="Chat Input")
+        text_output = gr.Textbox(value="Output:", label="Model output")
+
+        run_btn = gr.Button("Run model")
+
+
+
+        def on_click_fn(img,text): return generate(0, img, text)
+
+        run_btn.click(on_click_fn, inputs=[query_image,chat_input], outputs=[text_output])
+
+    with gr.Tab("🦓 Grounding"):
+        with gr.Row():
+            with gr.Column(scale=1):
+                query_image = gr.Image(type="pil")
+            with gr.Column(scale=1):
+                out_image = gr.Image(type="pil")
+        with gr.Row():
+            chat_input = gr.Textbox(lines=1, label="Chat Input")
+        text_output = gr.Textbox(value="Output:", label="Model output")
+
+        run_btn = gr.Button("Run model")
+
+
+        def on_click_fn(img, text): return generate(1, img, text)
+
+
+        run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output, out_image])
+
+    with gr.Tab("🔢 Counting objects"):
+        with gr.Row():
+            query_image = gr.Image(type="pil")
+        with gr.Row():
+            chat_input = gr.Textbox(lines=1, label="Chat Input")
+        text_output = gr.Textbox(value="Output:", label="Model output")
+
+        run_btn = gr.Button("Run model")
+
+
+        def on_click_fn(img,text): return generate(0, img, text)
+
+
+        run_btn.click(on_click_fn, inputs=[query_image, chat_input], outputs=[text_output])
+
+    with gr.Tab("🕵️ Visual Question Answering"):
+        with gr.Row():
+            query_image = gr.Image(type="pil")
+        with gr.Row():
+            question = gr.Textbox(lines=1, label="Question")
+        text_output = gr.Textbox(value="Output:", label="Model output")
+
+        run_btn = gr.Button("Run model")
+
+
+        def on_click_fn(img, txt): return generate(2, img, txt)
+
+
+        run_btn.click(
+            on_click_fn, inputs=[query_image, question], outputs=[text_output]
+        )
+
+    with gr.Tab("🌎 Custom"):
+        gr.Markdown(
+            """### Customize the demonstration by uploading your own images and text samples. 
+                    ### **Note: Any text prompt you use will be prepended with an 'Output:', so you don't need to include it in your prompt.**"""
+        )
+        with gr.Row():
+            query_image = gr.Image(type="pil")
+        with gr.Row():
+            question = gr.Textbox(lines=1, label="Question")
+        text_output = gr.Textbox(value="Output:", label="Model output")
+
+        run_btn = gr.Button("Run model")
+
+
+        def on_click_fn(img, txt): return generate(2, img, txt)
+
+
+        run_btn.click(
+            on_click_fn, inputs=[query_image, question], outputs=[text_output]
+        )
+
+demo.queue(concurrency_count=1)
+demo.launch()
diff --git a/multimodal/HISTORY.md b/multimodal/HISTORY.md
new file mode 100644
index 0000000000000000000000000000000000000000..556720509176152deea697bddb9070a138143888
--- /dev/null
+++ b/multimodal/HISTORY.md
@@ -0,0 +1,3 @@
+## 1.0.0
+
+* it works
\ No newline at end of file
diff --git a/multimodal/LICENSE b/multimodal/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..206be3ebbf3a41276af664447106615b0a954814
--- /dev/null
+++ b/multimodal/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner,  Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe,  Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith,  Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/multimodal/MODEL_CARD.md b/multimodal/MODEL_CARD.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1264ae72debb5cc083e995ceca73e8534422302
--- /dev/null
+++ b/multimodal/MODEL_CARD.md
@@ -0,0 +1,44 @@
+---
+language: en
+datasets:
+- laion2b
+---
+
+# OpenFlamingo-9B
+
+[Blog post]() | [Code](https://github.com/mlfoundations/open_flamingo) | [Demo](https://7164d2142d11.ngrok.app)
+
+OpenFlamingo is an open source implementation of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) models. 
+OpenFlamingo-9B is built off of [CLIP ViT-L/14](https://huggingface.co/openai/clip-vit-large-patch14) and [LLaMA-7B](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/).
+
+
+## Model Details
+We freeze the pretrained vision encoder and language model, and then we train connecting Perceiver modules and cross-attention layers, following the original Flamingo paper. 
+
+Our training data is a mixture of [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and a large interleaved image-text dataset called Multimodal C4, which will be released soon.
+
+The current model is an early checkpoint of an ongoing effort. This checkpoint has seen 5 million interleaved image-text examples from Multimodal C4 and 10 million samples from LAION 2B.
+
+## Uses
+OpenFlamingo-9B is intended to be used **for academic research purposes only.** Commercial use is prohibited, in line with LLaMA's non-commercial license.
+
+### Bias, Risks, and Limitations
+This model may generate inaccurate or offensive outputs, reflecting biases in its training data and pretrained priors. 
+
+In an effort to mitigate current potential biases and harms, we have deployed a text content filter on model outputs in the OpenFlamingo demo. We continue to red-team the model to understand and improve its safety.
+
+## Evaluation
+We've evaluated this checkpoint on the validation sets for two vision-language tasks: COCO captioning and VQAv2. Results are displayed below.
+
+**COCO (CIDEr)**
+
+|0-shot|4-shot|8-shot|16-shot|32-shot|
+|--|--|--|--|--|
+|65.52|74.28|79.26|81.84|84.52|
+
+
+**VQAv2 (VQA accuracy)**
+
+|0-shot|4-shot|8-shot|16-shot|32-shot|
+|---|---|---|---|---|
+|43.55|44.05|47.5|48.87|50.34|
diff --git a/multimodal/Makefile b/multimodal/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..d5cc3840bce9ce0e5aebc435f63ffa5b534d4a8f
--- /dev/null
+++ b/multimodal/Makefile
@@ -0,0 +1,19 @@
+install: ## [Local development] Upgrade pip, install requirements, install package.
+	python -m pip install -U pip
+	python -m pip install -e .
+
+install-dev: ## [Local development] Install test requirements
+	python -m pip install -r requirements-test.txt
+
+lint: ## [Local development] Run mypy, pylint and black
+	python -m mypy open_flamingo
+	python -m pylint open_flamingo
+	python -m black --check -l 120 open_flamingo
+
+black: ## [Local development] Auto-format python code using black
+	python -m black -l 120 .
+
+.PHONY: help
+
+help: # Run `make help` to get help on the make commands
+	@grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
\ No newline at end of file
diff --git a/multimodal/README.md b/multimodal/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e789675fd3dc0168eeb83756c6f24763451eeac7
--- /dev/null
+++ b/multimodal/README.md
@@ -0,0 +1,233 @@
+# 🦩 OpenFlamingo
+
+[![PyPI version](https://badge.fury.io/py/open_flamingo.svg)](https://badge.fury.io/py/open_flamingo)
+
+[Blog post](https://laion.ai/blog/open-flamingo/) | Paper (coming soon)
+
+Welcome to our open source version of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) model! In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models. We also provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) trained on a new Multimodal C4 dataset (coming soon). Please refer to our blog post for more details.
+
+This repo is still under development, and we hope to release better performing and larger OpenFlamingo models soon. If you have any questions, please feel free to open an issue. We also welcome contributions!
+
+# Table of Contents
+- [Installation](#installation)
+- [Approach](#approach)
+  * [Model architecture](#model-architecture)
+- [Usage](#usage)
+  * [Initializing an OpenFlamingo model](#initializing-an-openflamingo-model)
+  * [Generating text](#generating-text)
+- [Training](#training)
+  * [Dataset](#dataset)
+- [Evaluation](#evaluation)
+- [Future plans](#future-plans)
+- [Team](#team)
+- [Acknowledgments](#acknowledgments)
+- [Citing](#citing)
+
+# Installation
+
+To install the package in an existing environment, run 
+```
+pip install open-flamingo
+```
+
+or to create a conda environment for running OpenFlamingo, run
+```
+conda env create -f environment.yml
+```
+
+# Usage
+We provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) using a CLIP ViT-Large vision encoder and a LLaMA-7B language model. In general, we support any [CLIP vision encoder](https://huggingface.co/models?search=clip). For the language model, we support [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models.
+
+#### NOTE: To use LLaMA models, you will need to install the latest version of transformers via
+```
+pip install git+https://github.com/huggingface/transformers
+```
+Use this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) for converting LLaMA weights to HuggingFace format.
+
+## Initializing an OpenFlamingo model
+``` python
+from open_flamingo import create_model_and_transforms
+
+model, image_processor, tokenizer = create_model_and_transforms(
+    clip_vision_encoder_path="ViT-L-14",
+    clip_vision_encoder_pretrained="openai",
+    lang_encoder_path="<path to llama weights in HuggingFace format>",
+    tokenizer_path="<path to llama tokenizer in HuggingFace format>",
+    cross_attn_every_n_layers=4
+)
+
+# grab model checkpoint from huggingface hub
+from huggingface_hub import hf_hub_download
+import torch
+
+checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B", "checkpoint.pt")
+model.load_state_dict(torch.load(checkpoint_path), strict=False)
+```
+
+## Generating text
+Here is an example of generating text conditioned on interleaved images/text, in this case we will do few-shot image captioning.
+
+``` python
+from PIL import Image
+import requests
+
+"""
+Step 1: Load images
+"""
+demo_image_one = Image.open(
+    requests.get(
+        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
+    ).raw
+)
+
+demo_image_two = Image.open(
+    requests.get(
+        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
+        stream=True
+    ).raw
+)
+
+query_image = Image.open(
+    requests.get(
+        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg", 
+        stream=True
+    ).raw
+)
+
+
+"""
+Step 2: Preprocessing images
+Details: For OpenFlamingo, we expect the image to be a torch tensor of shape 
+ batch_size x num_media x num_frames x channels x height x width. 
+ In this case batch_size = 1, num_media = 3, num_frames = 1 
+ (this will always be one expect for video which we don't support yet), 
+ channels = 3, height = 224, width = 224.
+"""
+vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
+vision_x = torch.cat(vision_x, dim=0)
+vision_x = vision_x.unsqueeze(1).unsqueeze(0)
+
+"""
+Step 3: Preprocessing text
+Details: In the text we expect an <|#image#|> special token to indicate where an image is.
+ We also expect an <|endofchunk|> special token to indicate the end of the text 
+ portion associated with an image.
+"""
+tokenizer.padding_side = "left" # For generation padding tokens should be on the left
+lang_x = tokenizer(
+    ["<|#image#|>An image of two cats.<|endofchunk|><|#image#|>An image of a bathroom sink.<|endofchunk|><|#image#|>An image of"],
+    return_tensors="pt",
+)
+
+
+"""
+Step 4: Generate text
+"""
+generated_text = model.generate(
+    vision_x=vision_x,
+    lang_x=lang_x["input_ids"],
+    attention_mask=lang_x["attention_mask"],
+    max_new_tokens=20,
+    num_beams=3,
+)
+
+print("Generated text: ", tokenizer.decode(generated_text[0]))
+```
+
+# Approach
+OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text. For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context training.
+
+## Model architecture
+OpenFlamingo seeks to fuse a pretrained vision encoder and a language model using cross attention layers. The model architecture is shown below.
+
+![OpenFlamingo architecture](docs/flamingo.png) 
+Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)
+
+# Training
+To train a model, modify the following example command, which uses OPT 1.3B as an example LM:
+```
+torchrun --nnodes=1 --nproc_per_node=4 train.py \
+--run_name flamingo3B \
+--lm_path facebook/opt-1.3b \
+--tokenizer_path facebook/opt-1.3b \
+--dataset_resampled \
+--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
+--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
+--batch_size_mmc4 4 \
+--batch_size_laion 8 \
+--train_num_samples_mmc4 125000 \
+--train_num_samples_laion 250000 \
+--loss_multiplier_laion 0.2 \
+--workers=6 \
+--num_epochs 250 \
+--lr_scheduler constant \
+--warmup_steps 5000 \
+--use_media_placement_augmentation \
+--mmc4_textsim_threshold 30
+```
+
+## Dataset
+We expect all our training datasets to be [WebDataset](https://github.com/webdataset/webdataset) shards.
+We train our models on the [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and Multimodal C4 (coming soon) datasets. By default the LAION 2B dataset is in WebDataset format if it is downloaded using the [img2dataset tool](https://github.com/rom1504/img2dataset) and Multimodal C4 comes packaged in the WebDataset format.
+
+
+# Evaluation
+We currently support running evaluations on [COCO](https://cocodataset.org/#home), [VQAv2](https://visualqa.org/index.html), [OKVQA](https://okvqa.allenai.org), [Flickr30k](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset), and [ImageNet](https://image-net.org/index.php). Note that currently these evaluations are ran in validation mode (as specified in the Flamingo paper). We will be adding support for running evaluations in test mode in the future.
+
+Before evaluating the model, you will need to install the coco evaluation package by running the following command:
+```
+pip install pycocoevalcap
+```
+
+To run evaluations on OKVQA you will need to run the following command:
+```
+import nltk
+nltk.download('wordnet')
+```
+
+To evaluate the model, run the script at `open_flamingo/scripts/run_eval.sh`
+
+# Future plans
+- [ ] Add support for video input
+- [ ] Release better performing and larger OpenFlamingo models
+- [ ] Expand our evaluation suite
+- [ ] Add support for FSDP training
+
+# Team
+
+OpenFlamingo is developed by:
+
+[Anas Awadalla](https://anas-awadalla.streamlit.app/), [Irena Gao](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/).
+
+The team is primarily from the University of Washington, Stanford, AI2, UCSB, and Google.
+
+# Acknowledgments
+This code is based on Lucidrains' [flamingo implementation](https://github.com/lucidrains/flamingo-pytorch) and David Hansmair's [flamingo-mini repo](https://github.com/dhansmair/flamingo-mini). Thank you for making your code public! We also thank the [OpenCLIP](https://github.com/mlfoundations/open_clip) team as we use their data loading code and take inspiration from their library design.
+
+We would also like to thank [Jean-Baptiste Alayrac](https://www.jbalayrac.com) and [Antoine Miech](https://antoine77340.github.io) for their advice, [Rohan Taori](https://www.rohantaori.com/), [Nicholas Schiefer](https://nicholasschiefer.com/), [Deep Ganguli](https://hai.stanford.edu/people/deep-ganguli), [Thomas Liao](https://thomasliao.com/), [Tatsunori Hashimoto](https://thashim.github.io/), and [Nicholas Carlini](https://nicholas.carlini.com/) for their help with assessing the safety risks of our release, and to [Stability AI](https://stability.ai) for providing us with compute resources to train these models.
+
+# Citing
+If you found this repository useful, please consider citing:
+
+```
+@software{anas_awadalla_2023_7733589,
+  author = {Awadalla, Anas and Gao, Irena and Gardner, Joshua and Hessel, Jack and Hanafy, Yusuf and Zhu, Wanrong and Marathe, Kalyani and Bitton, Yonatan and Gadre, Samir and Jitsev, Jenia and Kornblith, Simon and Koh, Pang Wei and Ilharco, Gabriel and Wortsman, Mitchell and Schmidt, Ludwig},
+  title = {OpenFlamingo},
+  month        = mar,
+  year         = 2023,
+  publisher    = {Zenodo},
+  version      = {v0.1.1},
+  doi          = {10.5281/zenodo.7733589},
+  url          = {https://doi.org/10.5281/zenodo.7733589}
+}
+```
+
+```
+@article{Alayrac2022FlamingoAV,
+  title={Flamingo: a Visual Language Model for Few-Shot Learning},
+  author={Jean-Baptiste Alayrac and Jeff Donahue and Pauline Luc and Antoine Miech and Iain Barr and Yana Hasson and Karel Lenc and Arthur Mensch and Katie Millican and Malcolm Reynolds and Roman Ring and Eliza Rutherford and Serkan Cabi and Tengda Han and Zhitao Gong and Sina Samangooei and Marianne Monteiro and Jacob Menick and Sebastian Borgeaud and Andy Brock and Aida Nematzadeh and Sahand Sharifzadeh and Mikolaj Binkowski and Ricardo Barreira and Oriol Vinyals and Andrew Zisserman and Karen Simonyan},
+  journal={ArXiv},
+  year={2022},
+  volume={abs/2204.14198}
+}
+```
diff --git a/multimodal/YOLOX/.gitignore b/multimodal/YOLOX/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..9842565a17ef40856a03a8cf7a4c7b672520a868
--- /dev/null
+++ b/multimodal/YOLOX/.gitignore
@@ -0,0 +1,228 @@
+### Linux ###
+*~
+
+# user experiments directory
+YOLOX_outputs/
+datasets/
+# do not ignore datasets under yolox/data
+!*yolox/data/datasets/
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### PyCharm ###
+# User-specific stuff
+.idea
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+# JetBrains templates
+**___jb_tmp___
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+docs/build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don’t work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+### Vim ###
+# Swap
+[._]*.s[a-v][a-z]
+[._]*.sw[a-p]
+[._]s[a-rt-v][a-z]
+[._]ss[a-gi-z]
+[._]sw[a-p]
+
+# Session
+Session.vim
+
+# Temporary
+.netrwhist
+# Auto-generated tag files
+tags
+# Persistent undo
+[._]*.un~
+
+# output
+docs/api
+.code-workspace.code-workspace
+*.pkl
+*.npy
+*.pth
+*.onnx
+*.engine
+events.out.tfevents*
+
+# vscode
+*.code-workspace
+.vscode
+
+# vim
+.vim
+
+# OS generated files
+.DS_Store
+.DS_Store?
+.Trashes
+ehthumbs.db
+Thumbs.db
diff --git a/multimodal/YOLOX/.pre-commit-config.yaml b/multimodal/YOLOX/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5120983f908eae6415390f17c58194c671e59899
--- /dev/null
+++ b/multimodal/YOLOX/.pre-commit-config.yaml
@@ -0,0 +1,43 @@
+repos:
+  - repo: https://github.com/pycqa/flake8
+    rev: 3.8.3
+    hooks:
+      - id: flake8
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.1.0
+    hooks:
+      - id: check-added-large-files
+      - id: check-docstring-first
+      - id: check-executables-have-shebangs
+      - id: check-json
+      - id: check-yaml
+        args: ["--unsafe"]
+      - id: debug-statements
+      - id: end-of-file-fixer
+      - id: requirements-txt-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/jorisroovers/gitlint
+    rev: v0.15.1
+    hooks:
+      - id: gitlint
+  - repo: https://github.com/pycqa/isort
+    rev: 4.3.21
+    hooks:
+      - id: isort
+
+  - repo: https://github.com/PyCQA/autoflake
+    rev: v1.4
+    hooks:
+      - id: autoflake
+        name: Remove unused variables and imports
+        entry: autoflake
+        language: python
+        args:
+          [
+            "--in-place",
+            "--remove-all-unused-imports",
+            "--remove-unused-variables",
+            "--expand-star-imports",
+            "--ignore-init-module-imports",
+          ]
+        files: \.py$
diff --git a/multimodal/YOLOX/.readthedocs.yaml b/multimodal/YOLOX/.readthedocs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7e77c229649f01859f5f2945ebd002e52d1f1835
--- /dev/null
+++ b/multimodal/YOLOX/.readthedocs.yaml
@@ -0,0 +1,21 @@
+# .readthedocs.yaml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Build documentation in the docs/ directory with Sphinx
+sphinx:
+   configuration: docs/conf.py
+
+# Optionally build your docs in additional formats such as PDF
+formats:
+   - pdf
+
+# Optionally set the version of Python and requirements required to build your docs
+python:
+   version: "3.7"
+   install:
+   - requirements: docs/requirements-doc.txt
+   - requirements: requirements.txt
diff --git a/multimodal/YOLOX/LICENSE b/multimodal/YOLOX/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..1d4dc763d3d33d3722c6d86054c01b8a459bb2ea
--- /dev/null
+++ b/multimodal/YOLOX/LICENSE
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright (c) 2021-2022 Megvii Inc. All rights reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/multimodal/YOLOX/MANIFEST.in b/multimodal/YOLOX/MANIFEST.in
new file mode 100644
index 0000000000000000000000000000000000000000..aea4f44a71f15b2faf0b70f0583ffdad7e557f3f
--- /dev/null
+++ b/multimodal/YOLOX/MANIFEST.in
@@ -0,0 +1,2 @@
+include requirements.txt
+recursive-include yolox *.cpp *.h *.cu *.cuh *.cc
diff --git a/multimodal/YOLOX/README.md b/multimodal/YOLOX/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..18d9d404048c7a3b85a2e9a400ed2d6bfa772ae9
--- /dev/null
+++ b/multimodal/YOLOX/README.md
@@ -0,0 +1,255 @@
+<div align="center"><img src="assets/logo.png" width="350"></div>
+<img src="assets/demo.png" >
+
+## Introduction
+YOLOX is an anchor-free version of YOLO, with a simpler design but better performance! It aims to bridge the gap between research and industrial communities.
+For more details, please refer to our [report on Arxiv](https://arxiv.org/abs/2107.08430).
+
+This repo is an implementation of PyTorch version YOLOX, there is also a [MegEngine implementation](https://github.com/MegEngine/YOLOX).
+
+<img src="assets/git_fig.png" width="1000" >
+
+## Updates!!
+* 【2023/02/28】 We support assignment visualization tool, see doc [here](./docs/assignment_visualization.md).
+* 【2022/04/14】 We support jit compile op.
+* 【2021/08/19】 We optimize the training process with **2x** faster training and **~1%** higher performance! See [notes](docs/updates_note.md) for more details.
+* 【2021/08/05】 We release [MegEngine version YOLOX](https://github.com/MegEngine/YOLOX).
+* 【2021/07/28】 We fix the fatal error of [memory leak](https://github.com/Megvii-BaseDetection/YOLOX/issues/103)
+* 【2021/07/26】 We now support [MegEngine](https://github.com/Megvii-BaseDetection/YOLOX/tree/main/demo/MegEngine) deployment.
+* 【2021/07/20】 We have released our technical report on [Arxiv](https://arxiv.org/abs/2107.08430).
+
+## Coming soon
+- [ ] YOLOX-P6 and larger model.
+- [ ] Objects365 pretrain.
+- [ ] Transformer modules.
+- [ ] More features in need.
+
+## Benchmark
+
+#### Standard Models.
+
+|Model |size |mAP<sup>val<br>0.5:0.95 |mAP<sup>test<br>0.5:0.95 | Speed V100<br>(ms) | Params<br>(M) |FLOPs<br>(G)| weights |
+| ------        |:---: | :---:    | :---:       |:---:     |:---:  | :---: | :----: |
+|[YOLOX-s](./exps/default/yolox_s.py)    |640  |40.5 |40.5      |9.8      |9.0 | 26.8 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth) |
+|[YOLOX-m](./exps/default/yolox_m.py)    |640  |46.9 |47.2      |12.3     |25.3 |73.8| [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_m.pth) |
+|[YOLOX-l](./exps/default/yolox_l.py)    |640  |49.7 |50.1      |14.5     |54.2| 155.6 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_l.pth) |
+|[YOLOX-x](./exps/default/yolox_x.py)   |640   |51.1 |**51.5**  | 17.3    |99.1 |281.9 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x.pth) |
+|[YOLOX-Darknet53](./exps/default/yolov3.py)   |640  | 47.7 | 48.0 | 11.1 |63.7 | 185.3 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_darknet.pth) |
+
+<details>
+<summary>Legacy models</summary>
+
+|Model |size |mAP<sup>test<br>0.5:0.95 | Speed V100<br>(ms) | Params<br>(M) |FLOPs<br>(G)| weights |
+| ------        |:---: | :---:       |:---:     |:---:  | :---: | :----: |
+|[YOLOX-s](./exps/default/yolox_s.py)    |640  |39.6      |9.8     |9.0 | 26.8 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EW62gmO2vnNNs5npxjzunVwB9p307qqygaCkXdTO88BLUg?e=NMTQYw)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_s.pth) |
+|[YOLOX-m](./exps/default/yolox_m.py)    |640  |46.4      |12.3     |25.3 |73.8| [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/ERMTP7VFqrVBrXKMU7Vl4TcBQs0SUeCT7kvc-JdIbej4tQ?e=1MDo9y)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_m.pth) |
+|[YOLOX-l](./exps/default/yolox_l.py)    |640  |50.0  |14.5 |54.2| 155.6 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EWA8w_IEOzBKvuueBqfaZh0BeoG5sVzR-XYbOJO4YlOkRw?e=wHWOBE)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_l.pth) |
+|[YOLOX-x](./exps/default/yolox_x.py)   |640  |**51.2**      | 17.3 |99.1 |281.9 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EdgVPHBziOVBtGAXHfeHI5kBza0q9yyueMGdT0wXZfI1rQ?e=tABO5u)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_x.pth) |
+|[YOLOX-Darknet53](./exps/default/yolov3.py)   |640  | 47.4      | 11.1 |63.7 | 185.3 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EZ-MV1r_fMFPkPrNjvbJEMoBLOLAnXH-XKEB77w8LhXL6Q?e=mf6wOc)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_darknet53.pth) |
+
+</details>
+
+#### Light Models.
+
+|Model |size |mAP<sup>val<br>0.5:0.95 | Params<br>(M) |FLOPs<br>(G)| weights |
+| ------        |:---:  |  :---:       |:---:     |:---:  | :---: |
+|[YOLOX-Nano](./exps/default/yolox_nano.py) |416  |25.8  | 0.91 |1.08 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_nano.pth) |
+|[YOLOX-Tiny](./exps/default/yolox_tiny.py) |416  |32.8 | 5.06 |6.45 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_tiny.pth) |
+
+
+<details>
+<summary>Legacy models</summary>
+
+|Model |size |mAP<sup>val<br>0.5:0.95 | Params<br>(M) |FLOPs<br>(G)| weights |
+| ------        |:---:  |  :---:       |:---:     |:---:  | :---: |
+|[YOLOX-Nano](./exps/default/yolox_nano.py) |416  |25.3  | 0.91 |1.08 | [github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_nano.pth) |
+|[YOLOX-Tiny](./exps/default/yolox_tiny.py) |416  |32.8 | 5.06 |6.45 | [github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_tiny_32dot8.pth) |
+
+</details>
+
+## Quick Start
+
+<details>
+<summary>Installation</summary>
+
+Step1. Install YOLOX from source.
+```shell
+git clone git@github.com:Megvii-BaseDetection/YOLOX.git
+cd YOLOX
+pip3 install -v -e .  # or  python3 setup.py develop
+```
+
+</details>
+
+<details>
+<summary>Demo</summary>
+
+Step1. Download a pretrained model from the benchmark table.
+
+Step2. Use either -n or -f to specify your detector's config. For example:
+
+```shell
+python tools/demo.py image -n yolox-s -c /path/to/your/yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device [cpu/gpu]
+```
+or
+```shell
+python tools/demo.py image -f exps/default/yolox_s.py -c /path/to/your/yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device [cpu/gpu]
+```
+Demo for video:
+```shell
+python tools/demo.py video -n yolox-s -c /path/to/your/yolox_s.pth --path /path/to/your/video --conf 0.25 --nms 0.45 --tsize 640 --save_result --device [cpu/gpu]
+```
+
+
+</details>
+
+<details>
+<summary>Reproduce our results on COCO</summary>
+
+Step1. Prepare COCO dataset
+```shell
+cd <YOLOX_HOME>
+ln -s /path/to/your/COCO ./datasets/COCO
+```
+
+Step2. Reproduce our results on COCO by specifying -n:
+
+```shell
+python -m yolox.tools.train -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
+                               yolox-m
+                               yolox-l
+                               yolox-x
+```
+* -d: number of gpu devices
+* -b: total batch size, the recommended number for -b is num-gpu * 8
+* --fp16: mixed precision training
+* --cache: caching imgs into RAM to accelarate training, which need large system RAM.
+
+
+
+When using -f, the above commands are equivalent to:
+```shell
+python -m yolox.tools.train -f exps/default/yolox_s.py -d 8 -b 64 --fp16 -o [--cache]
+                               exps/default/yolox_m.py
+                               exps/default/yolox_l.py
+                               exps/default/yolox_x.py
+```
+
+**Multi Machine Training**
+
+We also support multi-nodes training. Just add the following args:
+* --num\_machines: num of your total training nodes
+* --machine\_rank: specify the rank of each node
+
+Suppose you want to train YOLOX on 2 machines, and your master machines's IP is 123.123.123.123, use port 12312 and TCP.
+
+On master machine, run
+```shell
+python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 0
+```
+On the second machine, run
+```shell
+python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 1
+```
+
+**Logging to Weights & Biases**
+
+To log metrics, predictions and model checkpoints to [W&B](https://docs.wandb.ai/guides/integrations/other/yolox) use the command line argument `--logger wandb` and use the prefix "wandb-" to specify arguments for initializing the wandb run.
+
+```shell
+python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+
+An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0)
+
+**Others**
+
+See more information with the following command:
+```shell
+python -m yolox.tools.train --help
+```
+
+</details>
+
+
+<details>
+<summary>Evaluation</summary>
+
+We support batch testing for fast evaluation:
+
+```shell
+python -m yolox.tools.eval -n  yolox-s -c yolox_s.pth -b 64 -d 8 --conf 0.001 [--fp16] [--fuse]
+                               yolox-m
+                               yolox-l
+                               yolox-x
+```
+* --fuse: fuse conv and bn
+* -d: number of GPUs used for evaluation. DEFAULT: All GPUs available will be used.
+* -b: total batch size across on all GPUs
+
+To reproduce speed test, we use the following command:
+```shell
+python -m yolox.tools.eval -n  yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --fp16 --fuse
+                               yolox-m
+                               yolox-l
+                               yolox-x
+```
+
+</details>
+
+
+<details>
+<summary>Tutorials</summary>
+
+*  [Training on custom data](docs/train_custom_data.md)
+*  [Caching for custom data](docs/cache.md)
+*  [Manipulating training image size](docs/manipulate_training_image_size.md)
+*  [Assignment visualization](docs/assignment_visualization.md)
+*  [Freezing model](docs/freeze_module.md)
+
+</details>
+
+## Deployment
+
+
+1. [MegEngine in C++ and Python](./demo/MegEngine)
+2. [ONNX export and an ONNXRuntime](./demo/ONNXRuntime)
+3. [TensorRT in C++ and Python](./demo/TensorRT)
+4. [ncnn in C++ and Java](./demo/ncnn)
+5. [OpenVINO in C++ and Python](./demo/OpenVINO)
+6. [Accelerate YOLOX inference with nebullvm in Python](./demo/nebullvm)
+
+## Third-party resources
+* YOLOX for streaming perception: [StreamYOLO (CVPR 2022 Oral)](https://github.com/yancie-yjr/StreamYOLO)
+* The YOLOX-s and YOLOX-nano are Integrated into [ModelScope](https://www.modelscope.cn/home). Try out the Online Demo at [YOLOX-s](https://www.modelscope.cn/models/damo/cv_cspnet_image-object-detection_yolox/summary) and [YOLOX-Nano](https://www.modelscope.cn/models/damo/cv_cspnet_image-object-detection_yolox_nano_coco/summary) respectively 🚀.
+* Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Sultannn/YOLOX-Demo)
+* The ncnn android app with video support: [ncnn-android-yolox](https://github.com/FeiGeChuanShu/ncnn-android-yolox) from [FeiGeChuanShu](https://github.com/FeiGeChuanShu)
+* YOLOX with Tengine support: [Tengine](https://github.com/OAID/Tengine/blob/tengine-lite/examples/tm_yolox.cpp) from [BUG1989](https://github.com/BUG1989)
+* YOLOX + ROS2 Foxy: [YOLOX-ROS](https://github.com/Ar-Ray-code/YOLOX-ROS) from [Ar-Ray](https://github.com/Ar-Ray-code)
+* YOLOX Deploy DeepStream: [YOLOX-deepstream](https://github.com/nanmi/YOLOX-deepstream) from [nanmi](https://github.com/nanmi)
+* YOLOX MNN/TNN/ONNXRuntime: [YOLOX-MNN](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/mnn/cv/mnn_yolox.cpp)、[YOLOX-TNN](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/tnn/cv/tnn_yolox.cpp) and [YOLOX-ONNXRuntime C++](https://github.com/DefTruth/lite.ai.toolkit/blob/main/lite/ort/cv/yolox.cpp) from [DefTruth](https://github.com/DefTruth)
+* Converting darknet or yolov5 datasets to COCO format for YOLOX: [YOLO2COCO](https://github.com/RapidAI/YOLO2COCO) from [Daniel](https://github.com/znsoftm)
+
+## Cite YOLOX
+If you use YOLOX in your research, please cite our work by using the following BibTeX entry:
+
+```latex
+ @article{yolox2021,
+  title={YOLOX: Exceeding YOLO Series in 2021},
+  author={Ge, Zheng and Liu, Songtao and Wang, Feng and Li, Zeming and Sun, Jian},
+  journal={arXiv preprint arXiv:2107.08430},
+  year={2021}
+}
+```
+## In memory of Dr. Jian Sun
+Without the guidance of [Dr. Jian Sun](http://www.jiansun.org/), YOLOX would not have been released and open sourced to the community.
+The passing away of Dr. Jian is a huge loss to the Computer Vision field. We add this section here to express our remembrance and condolences to our captain Dr. Jian.
+It is hoped that every AI practitioner in the world will stick to the concept of "continuous innovation to expand cognitive boundaries, and extraordinary technology to achieve product value" and move forward all the way.
+
+<div align="center"><img src="assets/sunjian.png" width="200"></div>
+没有孙剑博士的指导,YOLOX也不会问世并开源给社区使用。
+孙剑博士的离去是CV领域的一大损失,我们在此特别添加了这个部分来表达对我们的“船长”孙老师的纪念和哀思。
+希望世界上的每个AI从业者秉持着“持续创新拓展认知边界,非凡科技成就产品价值”的观念,一路向前。
diff --git a/multimodal/YOLOX/demo/MegEngine/cpp/README.md b/multimodal/YOLOX/demo/MegEngine/cpp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c639f27d96b731bbe58cd99f1347fdb9d11e2cb6
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/cpp/README.md
@@ -0,0 +1,173 @@
+# YOLOX-CPP-MegEngine
+
+Cpp file compile of YOLOX object detection base on [MegEngine](https://github.com/MegEngine/MegEngine).
+
+## Tutorial
+
+### Step1: install toolchain
+
+	* host: sudo apt install gcc/g++ (gcc/g++, which version >= 6) build-essential git git-lfs gfortran libgfortran-6-dev autoconf gnupg flex bison gperf curl zlib1g-dev gcc-multilib g++-multilib cmake
+ * cross build android: download [NDK](https://developer.android.com/ndk/downloads)
+   	* after unzip download NDK, then export NDK_ROOT="path of NDK"
+
+### Step2: build MegEngine
+
+```shell
+git clone https://github.com/MegEngine/MegEngine.git
+
+# then init third_party
+ 
+export megengine_root="path of MegEngine"
+cd $megengine_root && ./third_party/prepare.sh && ./third_party/install-mkl.sh
+
+# build example:
+# build host without cuda:   
+./scripts/cmake-build/host_build.sh
+# or build host with cuda:
+./scripts/cmake-build/host_build.sh -c
+# or cross build for android aarch64: 
+./scripts/cmake-build/cross_build_android_arm_inference.sh
+# or cross build for android aarch64(with V8.2+fp16): 
+./scripts/cmake-build/cross_build_android_arm_inference.sh -f
+
+# after build MegEngine, you need export the `MGE_INSTALL_PATH`
+# host without cuda: 
+export MGE_INSTALL_PATH=${megengine_root}/build_dir/host/MGE_WITH_CUDA_OFF/MGE_INFERENCE_ONLY_ON/Release/install
+# or host with cuda: 
+export MGE_INSTALL_PATH=${megengine_root}/build_dir/host/MGE_WITH_CUDA_ON/MGE_INFERENCE_ONLY_ON/Release/install
+# or cross build for android aarch64: 
+export MGE_INSTALL_PATH=${megengine_root}/build_dir/android/arm64-v8a/Release/install
+```
+* you can refs [build tutorial of MegEngine](https://github.com/MegEngine/MegEngine/blob/master/scripts/cmake-build/BUILD_README.md) to build other platform, eg, windows/macos/ etc!
+
+### Step3: build OpenCV
+
+```shell
+git clone https://github.com/opencv/opencv.git
+
+git checkout 3.4.15 (we test at 3.4.15, if test other version, may need modify some build)
+```
+
+- patch diff for android:
+
+```
+# ```
+#     diff --git a/CMakeLists.txt b/CMakeLists.txt
+#     index f6a2da5310..10354312c9 100644
+#     --- a/CMakeLists.txt
+#     +++ b/CMakeLists.txt
+#     @@ -643,7 +643,7 @@ if(UNIX)
+#        if(NOT APPLE)
+#          CHECK_INCLUDE_FILE(pthread.h HAVE_PTHREAD)
+#          if(ANDROID)
+#     -      set(OPENCV_LINKER_LIBS ${OPENCV_LINKER_LIBS} dl m log)
+#     +      set(OPENCV_LINKER_LIBS ${OPENCV_LINKER_LIBS} dl m log z)
+#          elseif(CMAKE_SYSTEM_NAME MATCHES "FreeBSD|NetBSD|DragonFly|OpenBSD|Haiku")
+#            set(OPENCV_LINKER_LIBS ${OPENCV_LINKER_LIBS} m pthread)
+#          elseif(EMSCRIPTEN)
+    
+# ```
+```
+
+- build for host
+
+```shell
+cd root_dir_of_opencv
+mkdir -p build/install
+cd build
+cmake -DBUILD_JAVA=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_INSTALL_PREFIX=$PWD/install 
+make install -j32
+```
+
+* build for android-aarch64
+
+```shell
+cd root_dir_of_opencv
+mkdir -p build_android/install
+cd build_android
+
+cmake -DCMAKE_TOOLCHAIN_FILE="$NDK_ROOT/build/cmake/android.toolchain.cmake" -DANDROID_NDK="$NDK_ROOT"  -DANDROID_ABI=arm64-v8a -DANDROID_NATIVE_API_LEVEL=21 -DBUILD_JAVA=OFF -DBUILD_ANDROID_PROJECTS=OFF -DBUILD_ANDROID_EXAMPLES=OFF -DBUILD_SHARED_LIBS=ON -DCMAKE_INSTALL_PREFIX=$PWD/install ..
+
+make install -j32
+```
+
+* after build OpenCV, you need export  `OPENCV_INSTALL_INCLUDE_PATH ` and `OPENCV_INSTALL_LIB_PATH`
+
+```shell
+# host build: 
+export OPENCV_INSTALL_INCLUDE_PATH=${path of opencv}/build/install/include
+export OPENCV_INSTALL_LIB_PATH=${path of opencv}/build/install/lib
+# or cross build for android aarch64:
+export OPENCV_INSTALL_INCLUDE_PATH=${path of opencv}/build_android/install/sdk/native/jni/include
+export OPENCV_INSTALL_LIB_PATH=${path of opencv}/build_android/install/sdk/native/libs/arm64-v8a
+```
+
+###  Step4: build test demo
+
+```shell
+run build.sh
+
+# if host:
+export CXX=g++
+./build.sh
+# or cross android aarch64
+export CXX=aarch64-linux-android21-clang++
+./build.sh
+```
+
+### Step5: run demo
+
+> **Note**: two ways to get `yolox_s.mge` model file
+>
+> * reference to python demo's `dump.py` script.
+> * For users with code before 0.1.0 version, wget yolox-s weights [here](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_s.mge).
+> * For users with code after 0.1.0 version, use [python code in megengine](../python) to generate mge file.
+
+```shell
+# if host:
+LD_LIBRARY_PATH=$MGE_INSTALL_PATH/lib/:$OPENCV_INSTALL_LIB_PATH ./yolox yolox_s.mge ../../../assets/dog.jpg cuda/cpu/multithread <warmup_count> <thread_number>
+
+# or cross android
+adb push/scp $MGE_INSTALL_PATH/lib/libmegengine.so android_phone
+adb push/scp $OPENCV_INSTALL_LIB_PATH/*.so android_phone
+adb push/scp ./yolox yolox_s.mge android_phone
+adb push/scp ../../../assets/dog.jpg android_phone
+
+# login in android_phone by adb or ssh
+# then run: 
+LD_LIBRARY_PATH=. ./yolox yolox_s.mge dog.jpg cpu/multithread <warmup_count> <thread_number> <use_fast_run> <use_weight_preprocess>  <run_with_fp16>
+
+# * <warmup_count> means warmup count, valid number >=0
+# * <thread_number> means thread number, valid number >=1, only take effect `multithread` device
+# * <use_fast_run> if >=1 , will use fastrun to choose best algo
+# * <use_weight_preprocess> if >=1, will handle weight preprocess before exe
+# * <run_with_fp16> if >=1, will run with fp16 mode
+```
+
+## Bechmark
+
+* model info: yolox-s @ input(1,3,640,640)					
+
+* test devices
+
+```
+  * x86_64  -- Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz					
+  * aarch64 -- xiamo phone mi9					
+  * cuda    -- 1080TI @ cuda-10.1-cudnn-v7.6.3-TensorRT-6.0.1.5.sh @ Intel(R) Xeon(R) CPU E5-2620 v4 @ 2.10GHz
+```
+
+  | megengine @ tag1.4(fastrun + weight\_preprocess)/sec | 1 thread |
+  | ---------------------------------------------------- | -------- |
+  | x86\_64                                              | 0.516245 |
+  | aarch64(fp32+chw44)                                  | 0.587857 |
+
+  | CUDA @ 1080TI/sec   | 1 batch    | 2 batch   | 4 batch   | 8 batch   | 16 batch  | 32 batch | 64 batch |
+  | ------------------- | ---------- | --------- | --------- | --------- | --------- | -------- | -------- |
+  | megengine(fp32+chw) | 0.00813703 | 0.0132893 | 0.0236633 | 0.0444699 | 0.0864917 | 0.16895  | 0.334248 |
+
+## Acknowledgement
+
+* [MegEngine](https://github.com/MegEngine/MegEngine)
+* [OpenCV](https://github.com/opencv/opencv)
+* [NDK](https://developer.android.com/ndk)
+* [CMAKE](https://cmake.org/)
diff --git a/multimodal/YOLOX/demo/MegEngine/cpp/build.sh b/multimodal/YOLOX/demo/MegEngine/cpp/build.sh
new file mode 100755
index 0000000000000000000000000000000000000000..0954305ab4ee9c76c68567c0ed851749049f5bab
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/cpp/build.sh
@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+set -e
+
+if [ -z $CXX ];then
+    echo "please export you c++ toolchain to CXX"
+    echo "for example:"
+    echo "build for host:                                        export CXX=g++"
+    echo "cross build for aarch64-android(always locate in NDK): export CXX=aarch64-linux-android21-clang++"
+    echo "cross build for aarch64-linux:                         export CXX=aarch64-linux-gnu-g++"
+    exit -1
+fi
+
+if [ -z $MGE_INSTALL_PATH ];then
+    echo "please refsi ./README.md to init MGE_INSTALL_PATH env"
+    exit -1
+fi
+
+if [ -z $OPENCV_INSTALL_INCLUDE_PATH ];then
+    echo "please refs ./README.md to init OPENCV_INSTALL_INCLUDE_PATH env"
+    exit -1
+fi
+
+if [ -z $OPENCV_INSTALL_LIB_PATH ];then
+    echo "please refs ./README.md to init OPENCV_INSTALL_LIB_PATH env"
+    exit -1
+fi
+
+INCLUDE_FLAG="-I$MGE_INSTALL_PATH/include -I$OPENCV_INSTALL_INCLUDE_PATH"
+LINK_FLAG="-L$MGE_INSTALL_PATH/lib/ -lmegengine -L$OPENCV_INSTALL_LIB_PATH -lopencv_core -lopencv_highgui -lopencv_imgproc -lopencv_imgcodecs"
+BUILD_FLAG="-static-libstdc++ -O3 -pie -fPIE -g"
+
+if [[ $CXX =~ "android" ]]; then
+    LINK_FLAG="${LINK_FLAG} -llog -lz"
+fi
+
+echo "CXX: $CXX"
+echo "MGE_INSTALL_PATH: $MGE_INSTALL_PATH"
+echo "INCLUDE_FLAG: $INCLUDE_FLAG"
+echo "LINK_FLAG: $LINK_FLAG"
+echo "BUILD_FLAG: $BUILD_FLAG"
+
+echo "[" > compile_commands.json
+echo "{" >> compile_commands.json
+echo "\"directory\": \"$PWD\"," >> compile_commands.json
+echo "\"command\": \"$CXX yolox.cpp -o yolox ${INCLUDE_FLAG} ${LINK_FLAG}\"," >> compile_commands.json
+echo "\"file\": \"$PWD/yolox.cpp\"," >> compile_commands.json
+echo "}," >> compile_commands.json
+echo "]" >> compile_commands.json
+$CXX yolox.cpp -o yolox ${INCLUDE_FLAG} ${LINK_FLAG} ${BUILD_FLAG}
+
+echo "build success, output file: yolox"
+if [[ $CXX =~ "android" ]]; then
+    echo "try command to run:"
+    echo "adb push/scp $MGE_INSTALL_PATH/lib/libmegengine.so android_phone"
+    echo "adb push/scp $OPENCV_INSTALL_LIB_PATH/*.so android_phone"
+    echo "adb push/scp ./yolox yolox_s.mge android_phone"
+    echo "adb push/scp ../../../assets/dog.jpg android_phone"
+    echo "adb/ssh to android_phone, then run: LD_LIBRARY_PATH=. ./yolox yolox_s.mge dog.jpg cpu/multithread <warmup_count> <thread_number> <use_fast_run> <use_weight_preprocess>"
+else
+    echo "try command to run: LD_LIBRARY_PATH=$MGE_INSTALL_PATH/lib/:$OPENCV_INSTALL_LIB_PATH ./yolox yolox_s.mge ../../../assets/dog.jpg cuda/cpu/multithread <warmup_count> <thread_number> <use_fast_run> <use_weight_preprocess>"
+fi
diff --git a/multimodal/YOLOX/demo/MegEngine/cpp/yolox.cpp b/multimodal/YOLOX/demo/MegEngine/cpp/yolox.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..859e6dcd2d1c42165a901d07a20e60ac256da912
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/cpp/yolox.cpp
@@ -0,0 +1,470 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+
+#include "megbrain/gopt/inference.h"
+#include "megbrain/opr/search_policy/algo_chooser_helper.h"
+#include "megbrain/serialization/serializer.h"
+#include <iostream>
+#include <iterator>
+#include <memory>
+#include <opencv2/opencv.hpp>
+#include <stdlib.h>
+#include <string>
+#include <vector>
+
+/**
+ * @brief Define names based depends on Unicode path support
+ */
+#define NMS_THRESH 0.45
+#define BBOX_CONF_THRESH 0.25
+
+constexpr int INPUT_W = 640;
+constexpr int INPUT_H = 640;
+
+using namespace mgb;
+
+cv::Mat static_resize(cv::Mat &img) {
+  float r = std::min(INPUT_W / (img.cols * 1.0), INPUT_H / (img.rows * 1.0));
+  int unpad_w = r * img.cols;
+  int unpad_h = r * img.rows;
+  cv::Mat re(unpad_h, unpad_w, CV_8UC3);
+  cv::resize(img, re, re.size());
+  cv::Mat out(INPUT_W, INPUT_H, CV_8UC3, cv::Scalar(114, 114, 114));
+  re.copyTo(out(cv::Rect(0, 0, re.cols, re.rows)));
+  return out;
+}
+
+void blobFromImage(cv::Mat &img, float *blob_data) {
+  int channels = 3;
+  int img_h = img.rows;
+  int img_w = img.cols;
+  for (size_t c = 0; c < channels; c++) {
+    for (size_t h = 0; h < img_h; h++) {
+      for (size_t w = 0; w < img_w; w++) {
+        blob_data[c * img_w * img_h + h * img_w + w] =
+            (float)img.at<cv::Vec3b>(h, w)[c];
+      }
+    }
+  }
+}
+
+struct Object {
+  cv::Rect_<float> rect;
+  int label;
+  float prob;
+};
+
+struct GridAndStride {
+  int grid0;
+  int grid1;
+  int stride;
+};
+
+static void
+generate_grids_and_stride(const int target_size, std::vector<int> &strides,
+                          std::vector<GridAndStride> &grid_strides) {
+  for (auto stride : strides) {
+    int num_grid = target_size / stride;
+    for (int g1 = 0; g1 < num_grid; g1++) {
+      for (int g0 = 0; g0 < num_grid; g0++) {
+        grid_strides.push_back((GridAndStride){g0, g1, stride});
+      }
+    }
+  }
+}
+
+static void generate_yolox_proposals(std::vector<GridAndStride> grid_strides,
+                                     const float *feat_ptr,
+                                     float prob_threshold,
+                                     std::vector<Object> &objects) {
+  const int num_class = 80;
+  const int num_anchors = grid_strides.size();
+
+  for (int anchor_idx = 0; anchor_idx < num_anchors; anchor_idx++) {
+    const int grid0 = grid_strides[anchor_idx].grid0;
+    const int grid1 = grid_strides[anchor_idx].grid1;
+    const int stride = grid_strides[anchor_idx].stride;
+
+    const int basic_pos = anchor_idx * 85;
+
+    float x_center = (feat_ptr[basic_pos + 0] + grid0) * stride;
+    float y_center = (feat_ptr[basic_pos + 1] + grid1) * stride;
+    float w = exp(feat_ptr[basic_pos + 2]) * stride;
+    float h = exp(feat_ptr[basic_pos + 3]) * stride;
+    float x0 = x_center - w * 0.5f;
+    float y0 = y_center - h * 0.5f;
+
+    float box_objectness = feat_ptr[basic_pos + 4];
+    for (int class_idx = 0; class_idx < num_class; class_idx++) {
+      float box_cls_score = feat_ptr[basic_pos + 5 + class_idx];
+      float box_prob = box_objectness * box_cls_score;
+      if (box_prob > prob_threshold) {
+        Object obj;
+        obj.rect.x = x0;
+        obj.rect.y = y0;
+        obj.rect.width = w;
+        obj.rect.height = h;
+        obj.label = class_idx;
+        obj.prob = box_prob;
+
+        objects.push_back(obj);
+      }
+
+    } // class loop
+
+  } // point anchor loop
+}
+
+static inline float intersection_area(const Object &a, const Object &b) {
+  cv::Rect_<float> inter = a.rect & b.rect;
+  return inter.area();
+}
+
+static void qsort_descent_inplace(std::vector<Object> &faceobjects, int left,
+                                  int right) {
+  int i = left;
+  int j = right;
+  float p = faceobjects[(left + right) / 2].prob;
+
+  while (i <= j) {
+    while (faceobjects[i].prob > p)
+      i++;
+
+    while (faceobjects[j].prob < p)
+      j--;
+
+    if (i <= j) {
+      // swap
+      std::swap(faceobjects[i], faceobjects[j]);
+
+      i++;
+      j--;
+    }
+  }
+
+#pragma omp parallel sections
+  {
+#pragma omp section
+    {
+      if (left < j)
+        qsort_descent_inplace(faceobjects, left, j);
+    }
+#pragma omp section
+    {
+      if (i < right)
+        qsort_descent_inplace(faceobjects, i, right);
+    }
+  }
+}
+
+static void qsort_descent_inplace(std::vector<Object> &objects) {
+  if (objects.empty())
+    return;
+
+  qsort_descent_inplace(objects, 0, objects.size() - 1);
+}
+
+static void nms_sorted_bboxes(const std::vector<Object> &faceobjects,
+                              std::vector<int> &picked, float nms_threshold) {
+  picked.clear();
+
+  const int n = faceobjects.size();
+
+  std::vector<float> areas(n);
+  for (int i = 0; i < n; i++) {
+    areas[i] = faceobjects[i].rect.area();
+  }
+
+  for (int i = 0; i < n; i++) {
+    const Object &a = faceobjects[i];
+
+    int keep = 1;
+    for (int j = 0; j < (int)picked.size(); j++) {
+      const Object &b = faceobjects[picked[j]];
+
+      // intersection over union
+      float inter_area = intersection_area(a, b);
+      float union_area = areas[i] + areas[picked[j]] - inter_area;
+      // float IoU = inter_area / union_area
+      if (inter_area / union_area > nms_threshold)
+        keep = 0;
+    }
+
+    if (keep)
+      picked.push_back(i);
+  }
+}
+
+static void decode_outputs(const float *prob, std::vector<Object> &objects,
+                           float scale, const int img_w, const int img_h) {
+  std::vector<Object> proposals;
+  std::vector<int> strides = {8, 16, 32};
+  std::vector<GridAndStride> grid_strides;
+
+  generate_grids_and_stride(INPUT_W, strides, grid_strides);
+  generate_yolox_proposals(grid_strides, prob, BBOX_CONF_THRESH, proposals);
+  qsort_descent_inplace(proposals);
+
+  std::vector<int> picked;
+  nms_sorted_bboxes(proposals, picked, NMS_THRESH);
+  int count = picked.size();
+  objects.resize(count);
+
+  for (int i = 0; i < count; i++) {
+    objects[i] = proposals[picked[i]];
+
+    // adjust offset to original unpadded
+    float x0 = (objects[i].rect.x) / scale;
+    float y0 = (objects[i].rect.y) / scale;
+    float x1 = (objects[i].rect.x + objects[i].rect.width) / scale;
+    float y1 = (objects[i].rect.y + objects[i].rect.height) / scale;
+
+    // clip
+    x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
+    y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
+    x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
+    y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);
+
+    objects[i].rect.x = x0;
+    objects[i].rect.y = y0;
+    objects[i].rect.width = x1 - x0;
+    objects[i].rect.height = y1 - y0;
+  }
+}
+
+const float color_list[80][3] = {
+    {0.000, 0.447, 0.741}, {0.850, 0.325, 0.098}, {0.929, 0.694, 0.125},
+    {0.494, 0.184, 0.556}, {0.466, 0.674, 0.188}, {0.301, 0.745, 0.933},
+    {0.635, 0.078, 0.184}, {0.300, 0.300, 0.300}, {0.600, 0.600, 0.600},
+    {1.000, 0.000, 0.000}, {1.000, 0.500, 0.000}, {0.749, 0.749, 0.000},
+    {0.000, 1.000, 0.000}, {0.000, 0.000, 1.000}, {0.667, 0.000, 1.000},
+    {0.333, 0.333, 0.000}, {0.333, 0.667, 0.000}, {0.333, 1.000, 0.000},
+    {0.667, 0.333, 0.000}, {0.667, 0.667, 0.000}, {0.667, 1.000, 0.000},
+    {1.000, 0.333, 0.000}, {1.000, 0.667, 0.000}, {1.000, 1.000, 0.000},
+    {0.000, 0.333, 0.500}, {0.000, 0.667, 0.500}, {0.000, 1.000, 0.500},
+    {0.333, 0.000, 0.500}, {0.333, 0.333, 0.500}, {0.333, 0.667, 0.500},
+    {0.333, 1.000, 0.500}, {0.667, 0.000, 0.500}, {0.667, 0.333, 0.500},
+    {0.667, 0.667, 0.500}, {0.667, 1.000, 0.500}, {1.000, 0.000, 0.500},
+    {1.000, 0.333, 0.500}, {1.000, 0.667, 0.500}, {1.000, 1.000, 0.500},
+    {0.000, 0.333, 1.000}, {0.000, 0.667, 1.000}, {0.000, 1.000, 1.000},
+    {0.333, 0.000, 1.000}, {0.333, 0.333, 1.000}, {0.333, 0.667, 1.000},
+    {0.333, 1.000, 1.000}, {0.667, 0.000, 1.000}, {0.667, 0.333, 1.000},
+    {0.667, 0.667, 1.000}, {0.667, 1.000, 1.000}, {1.000, 0.000, 1.000},
+    {1.000, 0.333, 1.000}, {1.000, 0.667, 1.000}, {0.333, 0.000, 0.000},
+    {0.500, 0.000, 0.000}, {0.667, 0.000, 0.000}, {0.833, 0.000, 0.000},
+    {1.000, 0.000, 0.000}, {0.000, 0.167, 0.000}, {0.000, 0.333, 0.000},
+    {0.000, 0.500, 0.000}, {0.000, 0.667, 0.000}, {0.000, 0.833, 0.000},
+    {0.000, 1.000, 0.000}, {0.000, 0.000, 0.167}, {0.000, 0.000, 0.333},
+    {0.000, 0.000, 0.500}, {0.000, 0.000, 0.667}, {0.000, 0.000, 0.833},
+    {0.000, 0.000, 1.000}, {0.000, 0.000, 0.000}, {0.143, 0.143, 0.143},
+    {0.286, 0.286, 0.286}, {0.429, 0.429, 0.429}, {0.571, 0.571, 0.571},
+    {0.714, 0.714, 0.714}, {0.857, 0.857, 0.857}, {0.000, 0.447, 0.741},
+    {0.314, 0.717, 0.741}, {0.50, 0.5, 0}};
+
+static void draw_objects(const cv::Mat &bgr,
+                         const std::vector<Object> &objects) {
+  static const char *class_names[] = {
+      "person",        "bicycle",      "car",
+      "motorcycle",    "airplane",     "bus",
+      "train",         "truck",        "boat",
+      "traffic light", "fire hydrant", "stop sign",
+      "parking meter", "bench",        "bird",
+      "cat",           "dog",          "horse",
+      "sheep",         "cow",          "elephant",
+      "bear",          "zebra",        "giraffe",
+      "backpack",      "umbrella",     "handbag",
+      "tie",           "suitcase",     "frisbee",
+      "skis",          "snowboard",    "sports ball",
+      "kite",          "baseball bat", "baseball glove",
+      "skateboard",    "surfboard",    "tennis racket",
+      "bottle",        "wine glass",   "cup",
+      "fork",          "knife",        "spoon",
+      "bowl",          "banana",       "apple",
+      "sandwich",      "orange",       "broccoli",
+      "carrot",        "hot dog",      "pizza",
+      "donut",         "cake",         "chair",
+      "couch",         "potted plant", "bed",
+      "dining table",  "toilet",       "tv",
+      "laptop",        "mouse",        "remote",
+      "keyboard",      "cell phone",   "microwave",
+      "oven",          "toaster",      "sink",
+      "refrigerator",  "book",         "clock",
+      "vase",          "scissors",     "teddy bear",
+      "hair drier",    "toothbrush"};
+
+  cv::Mat image = bgr.clone();
+
+  for (size_t i = 0; i < objects.size(); i++) {
+    const Object &obj = objects[i];
+
+    fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
+            obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);
+
+    cv::Scalar color =
+        cv::Scalar(color_list[obj.label][0], color_list[obj.label][1],
+                   color_list[obj.label][2]);
+    float c_mean = cv::mean(color)[0];
+    cv::Scalar txt_color;
+    if (c_mean > 0.5) {
+      txt_color = cv::Scalar(0, 0, 0);
+    } else {
+      txt_color = cv::Scalar(255, 255, 255);
+    }
+
+    cv::rectangle(image, obj.rect, color * 255, 2);
+
+    char text[256];
+    sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
+
+    int baseLine = 0;
+    cv::Size label_size =
+        cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
+
+    cv::Scalar txt_bk_color = color * 0.7 * 255;
+
+    int x = obj.rect.x;
+    int y = obj.rect.y + 1;
+    // int y = obj.rect.y - label_size.height - baseLine;
+    if (y > image.rows)
+      y = image.rows;
+    // if (x + label_size.width > image.cols)
+    // x = image.cols - label_size.width;
+
+    cv::rectangle(
+        image,
+        cv::Rect(cv::Point(x, y),
+                 cv::Size(label_size.width, label_size.height + baseLine)),
+        txt_bk_color, -1);
+
+    cv::putText(image, text, cv::Point(x, y + label_size.height),
+                cv::FONT_HERSHEY_SIMPLEX, 0.4, txt_color, 1);
+  }
+
+  cv::imwrite("out.jpg", image);
+  std::cout << "save output to out.jpg" << std::endl;
+}
+
+cg::ComputingGraph::OutputSpecItem make_callback_copy(SymbolVar dev,
+                                                      HostTensorND &host) {
+  auto cb = [&host](DeviceTensorND &d) { host.copy_from(d); };
+  return {dev, cb};
+}
+
+int main(int argc, char *argv[]) {
+  serialization::GraphLoader::LoadConfig load_config;
+  load_config.comp_graph = ComputingGraph::make();
+  auto &&graph_opt = load_config.comp_graph->options();
+  graph_opt.graph_opt_level = 0;
+
+  if (argc != 9) {
+    std::cout << "Usage : " << argv[0]
+              << " <path_to_model> <path_to_image> <device> <warmup_count> "
+                 "<thread_number> <use_fast_run> <use_weight_preprocess> "
+                 "<run_with_fp16>"
+              << std::endl;
+    return EXIT_FAILURE;
+  }
+
+  const std::string input_model{argv[1]};
+  const std::string input_image_path{argv[2]};
+  const std::string device{argv[3]};
+  const size_t warmup_count = atoi(argv[4]);
+  const size_t thread_number = atoi(argv[5]);
+  const size_t use_fast_run = atoi(argv[6]);
+  const size_t use_weight_preprocess = atoi(argv[7]);
+  const size_t run_with_fp16 = atoi(argv[8]);
+
+  if (device == "cuda") {
+    load_config.comp_node_mapper = [](CompNode::Locator &loc) {
+      loc.type = CompNode::DeviceType::CUDA;
+    };
+  } else if (device == "cpu") {
+    load_config.comp_node_mapper = [](CompNode::Locator &loc) {
+      loc.type = CompNode::DeviceType::CPU;
+    };
+  } else if (device == "multithread") {
+    load_config.comp_node_mapper = [thread_number](CompNode::Locator &loc) {
+      loc.type = CompNode::DeviceType::MULTITHREAD;
+      loc.device = 0;
+      loc.stream = thread_number;
+    };
+    std::cout << "use " << thread_number << " thread" << std::endl;
+  } else {
+    std::cout << "device only support cuda or cpu or multithread" << std::endl;
+    return EXIT_FAILURE;
+  }
+
+  if (use_weight_preprocess) {
+    std::cout << "use weight preprocess" << std::endl;
+    graph_opt.graph_opt.enable_weight_preprocess();
+  }
+  if (run_with_fp16) {
+    std::cout << "run with fp16" << std::endl;
+    graph_opt.graph_opt.enable_f16_io_comp();
+  }
+
+  if (device == "cuda") {
+    std::cout << "choose format for cuda" << std::endl;
+  } else {
+    std::cout << "choose format for non-cuda" << std::endl;
+#if defined(__arm__) || defined(__aarch64__)
+    if (run_with_fp16) {
+      std::cout << "use chw format when enable fp16" << std::endl;
+    } else {
+      std::cout << "choose format for nchw44 for aarch64" << std::endl;
+      graph_opt.graph_opt.enable_nchw44();
+    }
+#endif
+#if defined(__x86_64__) || defined(__amd64__) || defined(__i386__)
+    // graph_opt.graph_opt.enable_nchw88();
+#endif
+  }
+
+  std::unique_ptr<serialization::InputFile> inp_file =
+      serialization::InputFile::make_fs(input_model.c_str());
+  auto loader = serialization::GraphLoader::make(std::move(inp_file));
+  serialization::GraphLoader::LoadResult network =
+      loader->load(load_config, false);
+
+  if (use_fast_run) {
+    std::cout << "use fastrun" << std::endl;
+    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+    S strategy = static_cast<S>(0);
+    strategy = S::PROFILE | S::OPTIMIZED | strategy;
+    mgb::gopt::modify_opr_algo_strategy_inplace(network.output_var_list,
+                                                strategy);
+  }
+
+  auto data = network.tensor_map["data"];
+  cv::Mat image = cv::imread(input_image_path);
+  cv::Mat pr_img = static_resize(image);
+  float *data_ptr = data->resize({1, 3, 640, 640}).ptr<float>();
+  blobFromImage(pr_img, data_ptr);
+  HostTensorND predict;
+  std::unique_ptr<cg::AsyncExecutable> func = network.graph->compile(
+      {make_callback_copy(network.output_var_map.begin()->second, predict)});
+
+  for (auto i = 0; i < warmup_count; i++) {
+    std::cout << "warmup: " << i << std::endl;
+    func->execute();
+    func->wait();
+  }
+  auto start = std::chrono::system_clock::now();
+  func->execute();
+  func->wait();
+  auto end = std::chrono::system_clock::now();
+  std::chrono::duration<double> exec_seconds = end - start;
+  std::cout << "elapsed time: " << exec_seconds.count() << "s" << std::endl;
+
+  float *predict_ptr = predict.ptr<float>();
+  int img_w = image.cols;
+  int img_h = image.rows;
+  float scale =
+      std::min(INPUT_W / (image.cols * 1.0), INPUT_H / (image.rows * 1.0));
+  std::vector<Object> objects;
+
+  decode_outputs(predict_ptr, objects, scale, img_w, img_h);
+  draw_objects(image, objects);
+
+  return EXIT_SUCCESS;
+}
diff --git a/multimodal/YOLOX/demo/MegEngine/python/README.md b/multimodal/YOLOX/demo/MegEngine/python/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..97ec25563229fcc2914deb80c1135cda8d49bfb2
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/README.md
@@ -0,0 +1,33 @@
+# YOLOX-Python-MegEngine
+
+Python version of YOLOX object detection base on [MegEngine](https://github.com/MegEngine/MegEngine).
+
+## Tutorial
+
+### Step1: install requirements
+
+```
+python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html
+```
+
+### Step2: convert checkpoint weights from torch's path file
+
+```
+python3 convert_weights.py -w yolox_s.pth -o yolox_s_mge.pkl
+```
+
+### Step3: run demo
+
+This part is the same as torch's python demo, but no need to specify device.
+
+```
+python3 demo.py image -n yolox-s -c yolox_s_mge.pkl --path ../../../assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result
+```
+
+###  [Optional]Step4: dump model for cpp inference
+
+> **Note**: result model is dumped with `optimize_for_inference` and `enable_fuse_conv_bias_nonlinearity`.
+
+```
+python3 dump.py -n yolox-s -c yolox_s_mge.pkl --dump_path yolox_s.mge
+```
diff --git a/multimodal/YOLOX/demo/MegEngine/python/build.py b/multimodal/YOLOX/demo/MegEngine/python/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..139f4e7c7302e6ad9c3ae09c2918599b2b192a03
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/build.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import megengine as mge
+import megengine.module as M
+
+from models.yolo_fpn import YOLOFPN
+from models.yolo_head import YOLOXHead
+from models.yolo_pafpn import YOLOPAFPN
+from models.yolox import YOLOX
+
+
+def build_yolox(name="yolox-s"):
+    num_classes = 80
+
+    # value meaning: depth, width
+    param_dict = {
+        "yolox-nano": (0.33, 0.25),
+        "yolox-tiny": (0.33, 0.375),
+        "yolox-s": (0.33, 0.50),
+        "yolox-m": (0.67, 0.75),
+        "yolox-l": (1.0, 1.0),
+        "yolox-x": (1.33, 1.25),
+    }
+    if name == "yolov3":
+        depth = 1.0
+        width = 1.0
+        backbone = YOLOFPN()
+        head = YOLOXHead(num_classes, width, in_channels=[128, 256, 512], act="lrelu")
+        model = YOLOX(backbone, head)
+    else:
+        assert name in param_dict
+        kwargs = {}
+        depth, width = param_dict[name]
+        if name == "yolox-nano":
+            kwargs["depthwise"] = True
+        in_channels = [256, 512, 1024]
+        backbone = YOLOPAFPN(depth, width, in_channels=in_channels, **kwargs)
+        head = YOLOXHead(num_classes, width, in_channels=in_channels, **kwargs)
+        model = YOLOX(backbone, head)
+
+    for m in model.modules():
+        if isinstance(m, M.BatchNorm2d):
+            m.eps = 1e-3
+
+    return model
+
+
+def build_and_load(weight_file, name="yolox-s"):
+    model = build_yolox(name)
+    model_weights = mge.load(weight_file)
+    model.load_state_dict(model_weights, strict=False)
+    return model
diff --git a/multimodal/YOLOX/demo/MegEngine/python/convert_weights.py b/multimodal/YOLOX/demo/MegEngine/python/convert_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..198caeeb38efe5400323828e4c0e91ba94a99167
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/convert_weights.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+import argparse
+from collections import OrderedDict
+
+import megengine as mge
+import torch
+
+
+def make_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-w", "--weights", type=str, help="path of weight file")
+    parser.add_argument(
+        "-o",
+        "--output",
+        default="weight_mge.pkl",
+        type=str,
+        help="path of weight file",
+    )
+    return parser
+
+
+def numpy_weights(weight_file):
+    torch_weights = torch.load(weight_file, map_location="cpu")
+    if "model" in torch_weights:
+        torch_weights = torch_weights["model"]
+    new_dict = OrderedDict()
+    for k, v in torch_weights.items():
+        new_dict[k] = v.cpu().numpy()
+    return new_dict
+
+
+def map_weights(weight_file, output_file):
+    torch_weights = numpy_weights(weight_file)
+
+    new_dict = OrderedDict()
+    for k, v in torch_weights.items():
+        if "num_batches_tracked" in k:
+            print("drop: {}".format(k))
+            continue
+        if k.endswith("bias"):
+            print("bias key: {}".format(k))
+            v = v.reshape(1, -1, 1, 1)
+            new_dict[k] = v
+        elif "dconv" in k and "conv.weight" in k:
+            print("depthwise conv key: {}".format(k))
+            cout, cin, k1, k2 = v.shape
+            v = v.reshape(cout, 1, cin, k1, k2)
+            new_dict[k] = v
+        else:
+            new_dict[k] = v
+
+    mge.save(new_dict, output_file)
+    print("save weights to {}".format(output_file))
+
+
+def main():
+    parser = make_parser()
+    args = parser.parse_args()
+    map_weights(args.weights, args.output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/YOLOX/demo/MegEngine/python/demo.py b/multimodal/YOLOX/demo/MegEngine/python/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..6542853a1a0eb1f8882892fcf55fff8838bd1468
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/demo.py
@@ -0,0 +1,237 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+import time
+
+import cv2
+import megengine as mge
+import megengine.functional as F
+from loguru import logger
+
+from yolox.data.datasets import COCO_CLASSES
+from yolox.utils import vis
+from yolox.data.data_augment import preproc as preprocess
+
+from build import build_and_load
+
+IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX Demo!")
+    parser.add_argument(
+        "demo", default="image", help="demo type, eg. image, video and webcam"
+    )
+    parser.add_argument("-n", "--name", type=str, default="yolox-s", help="model name")
+    parser.add_argument("--path", default="./test.png", help="path to images or video")
+    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
+    parser.add_argument(
+        "--save_result",
+        action="store_true",
+        help="whether to save the inference result of image/video",
+    )
+
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
+    parser.add_argument("--conf", default=None, type=float, help="test conf")
+    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
+    parser.add_argument("--tsize", default=None, type=int, help="test img size")
+    return parser
+
+
+def get_image_list(path):
+    image_names = []
+    for maindir, subdir, file_name_list in os.walk(path):
+        for filename in file_name_list:
+            apath = os.path.join(maindir, filename)
+            ext = os.path.splitext(apath)[1]
+            if ext in IMAGE_EXT:
+                image_names.append(apath)
+    return image_names
+
+
+def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
+    box_corner = F.zeros_like(prediction)
+    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+    prediction[:, :, :4] = box_corner[:, :, :4]
+
+    output = [None for _ in range(len(prediction))]
+    for i, image_pred in enumerate(prediction):
+
+        # If none are remaining => process next image
+        if not image_pred.shape[0]:
+            continue
+        # Get score and class with highest confidence
+        class_conf = F.max(image_pred[:, 5: 5 + num_classes], 1, keepdims=True)
+        class_pred = F.argmax(image_pred[:, 5: 5 + num_classes], 1, keepdims=True)
+
+        class_conf_squeeze = F.squeeze(class_conf)
+        conf_mask = image_pred[:, 4] * class_conf_squeeze >= conf_thre
+        detections = F.concat((image_pred[:, :5], class_conf, class_pred), 1)
+        detections = detections[conf_mask]
+        if not detections.shape[0]:
+            continue
+
+        nms_out_index = F.vision.nms(
+            detections[:, :4], detections[:, 4] * detections[:, 5], nms_thre,
+        )
+        detections = detections[nms_out_index]
+        if output[i] is None:
+            output[i] = detections
+        else:
+            output[i] = F.concat((output[i], detections))
+
+    return output
+
+
+class Predictor(object):
+    def __init__(
+        self,
+        model,
+        confthre=0.01,
+        nmsthre=0.65,
+        test_size=(640, 640),
+        cls_names=COCO_CLASSES,
+        trt_file=None,
+        decoder=None,
+    ):
+        self.model = model
+        self.cls_names = cls_names
+        self.decoder = decoder
+        self.num_classes = 80
+        self.confthre = confthre
+        self.nmsthre = nmsthre
+        self.test_size = test_size
+
+    def inference(self, img):
+        img_info = {"id": 0}
+        if isinstance(img, str):
+            img_info["file_name"] = os.path.basename(img)
+            img = cv2.imread(img)
+            if img is None:
+                raise ValueError("test image path is invalid!")
+        else:
+            img_info["file_name"] = None
+
+        height, width = img.shape[:2]
+        img_info["height"] = height
+        img_info["width"] = width
+        img_info["raw_img"] = img
+
+        img, ratio = preprocess(img, self.test_size)
+        img_info["ratio"] = ratio
+        img = F.expand_dims(mge.tensor(img), 0)
+
+        t0 = time.time()
+        outputs = self.model(img)
+        outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre)
+        logger.info("Infer time: {:.4f}s".format(time.time() - t0))
+        return outputs, img_info
+
+    def visual(self, output, img_info, cls_conf=0.35):
+        ratio = img_info["ratio"]
+        img = img_info["raw_img"]
+        if output is None:
+            return img
+        output = output.numpy()
+
+        # preprocessing: resize
+        bboxes = output[:, 0:4] / ratio
+
+        cls = output[:, 6]
+        scores = output[:, 4] * output[:, 5]
+
+        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
+        return vis_res
+
+
+def image_demo(predictor, vis_folder, path, current_time, save_result):
+    if os.path.isdir(path):
+        files = get_image_list(path)
+    else:
+        files = [path]
+    files.sort()
+    for image_name in files:
+        outputs, img_info = predictor.inference(image_name)
+        result_image = predictor.visual(outputs[0], img_info)
+        if save_result:
+            save_folder = os.path.join(
+                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
+            )
+            os.makedirs(save_folder, exist_ok=True)
+            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
+            logger.info("Saving detection result in {}".format(save_file_name))
+            cv2.imwrite(save_file_name, result_image)
+        ch = cv2.waitKey(0)
+        if ch == 27 or ch == ord("q") or ch == ord("Q"):
+            break
+
+
+def imageflow_demo(predictor, vis_folder, current_time, args):
+    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
+    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
+    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    save_folder = os.path.join(
+        vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
+    )
+    os.makedirs(save_folder, exist_ok=True)
+    if args.demo == "video":
+        save_path = os.path.join(save_folder, os.path.basename(args.path))
+    else:
+        save_path = os.path.join(save_folder, "camera.mp4")
+    logger.info(f"video save_path is {save_path}")
+    vid_writer = cv2.VideoWriter(
+        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
+    )
+    while True:
+        ret_val, frame = cap.read()
+        if ret_val:
+            outputs, img_info = predictor.inference(frame)
+            result_frame = predictor.visual(outputs[0], img_info)
+            if args.save_result:
+                vid_writer.write(result_frame)
+            ch = cv2.waitKey(1)
+            if ch == 27 or ch == ord("q") or ch == ord("Q"):
+                break
+        else:
+            break
+
+
+def main(args):
+    file_name = os.path.join("./yolox_outputs", args.name)
+    os.makedirs(file_name, exist_ok=True)
+
+    if args.save_result:
+        vis_folder = os.path.join(file_name, "vis_res")
+        os.makedirs(vis_folder, exist_ok=True)
+
+    confthre = 0.01
+    nmsthre = 0.65
+    test_size = (640, 640)
+    if args.conf is not None:
+        confthre = args.conf
+    if args.nms is not None:
+        nmsthre = args.nms
+    if args.tsize is not None:
+        test_size = (args.tsize, args.tsize)
+
+    model = build_and_load(args.ckpt, name=args.name)
+    model.eval()
+
+    predictor = Predictor(model, confthre, nmsthre, test_size, COCO_CLASSES, None, None)
+    current_time = time.localtime()
+    if args.demo == "image":
+        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
+    elif args.demo == "video" or args.demo == "webcam":
+        imageflow_demo(predictor, vis_folder, current_time, args)
+
+
+if __name__ == "__main__":
+    args = make_parser().parse_args()
+    main(args)
diff --git a/multimodal/YOLOX/demo/MegEngine/python/dump.py b/multimodal/YOLOX/demo/MegEngine/python/dump.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ca1215bccb2f450e7cba1d971998531c38366cf
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/dump.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+
+import megengine as mge
+import numpy as np
+from megengine import jit
+
+from build import build_and_load
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX Demo Dump")
+    parser.add_argument("-n", "--name", type=str, default="yolox-s", help="model name")
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
+    parser.add_argument(
+        "--dump_path", default="model.mge", help="path to save the dumped model"
+    )
+    return parser
+
+
+def dump_static_graph(model, graph_name="model.mge"):
+    model.eval()
+    model.head.decode_in_inference = False
+
+    data = mge.Tensor(np.random.random((1, 3, 640, 640)))
+
+    @jit.trace(capture_as_const=True)
+    def pred_func(data):
+        outputs = model(data)
+        return outputs
+
+    pred_func(data)
+    pred_func.dump(
+        graph_name,
+        arg_names=["data"],
+        optimize_for_inference=True,
+        enable_fuse_conv_bias_nonlinearity=True,
+    )
+
+
+def main(args):
+    model = build_and_load(args.ckpt, name=args.name)
+    dump_static_graph(model, args.dump_path)
+
+
+if __name__ == "__main__":
+    args = make_parser().parse_args()
+    main(args)
diff --git a/multimodal/YOLOX/demo/MegEngine/python/models/__init__.py b/multimodal/YOLOX/demo/MegEngine/python/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e55d18e337f0f1630afef4312fb9c7a1cdd293e8
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/models/__init__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+from .darknet import CSPDarknet, Darknet
+from .yolo_fpn import YOLOFPN
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+from .yolox import YOLOX
diff --git a/multimodal/YOLOX/demo/MegEngine/python/models/darknet.py b/multimodal/YOLOX/demo/MegEngine/python/models/darknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..47469aa683a91cdf88091956b71637cae7a97dc3
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/models/darknet.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import megengine.module as M
+
+from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck
+
+
+class Darknet(M.Module):
+    # number of blocks from dark2 to dark5.
+    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}
+
+    def __init__(
+        self, depth, in_channels=3, stem_out_channels=32, out_features=("dark3", "dark4", "dark5"),
+    ):
+        """
+        Args:
+            depth (int): depth of darknet used in model, usually use [21, 53] for this param.
+            in_channels (int): number of input channels, for example, use 3 for RGB image.
+            stem_out_channels (int): number of output channels of darknet stem.
+                It decides channels of darknet layer2 to layer5.
+            out_features (Tuple[str]): desired output layer name.
+        """
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        self.stem = M.Sequential(
+            BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
+            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
+        )
+        in_channels = stem_out_channels * 2  # 64
+
+        num_blocks = Darknet.depth2blocks[depth]
+        # create darknet with `stem_out_channels` and `num_blocks` layers.
+        # to make model structure more clear, we don't use `for` statement in python.
+        self.dark2 = M.Sequential(*self.make_group_layer(in_channels, num_blocks[0], stride=2))
+        in_channels *= 2  # 128
+        self.dark3 = M.Sequential(*self.make_group_layer(in_channels, num_blocks[1], stride=2))
+        in_channels *= 2  # 256
+        self.dark4 = M.Sequential(*self.make_group_layer(in_channels, num_blocks[2], stride=2))
+        in_channels *= 2  # 512
+
+        self.dark5 = M.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[3], stride=2),
+            *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2),
+        )
+
+    def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
+        "starts with conv layer then has `num_blocks` `ResLayer`"
+        return [
+            BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"),
+            *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)]
+        ]
+
+    def make_spp_block(self, filters_list, in_filters):
+        m = M.Sequential(
+            *[
+                BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                SPPBottleneck(
+                    in_channels=filters_list[1],
+                    out_channels=filters_list[0],
+                    activation="lrelu"
+                ),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"),
+            ]
+        )
+        return m
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
+
+
+class CSPDarknet(M.Module):
+
+    def __init__(
+        self, dep_mul, wid_mul,
+        out_features=("dark3", "dark4", "dark5"),
+        depthwise=False, act="silu",
+    ):
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        Conv = DWConv if depthwise else BaseConv
+
+        base_channels = int(wid_mul * 64)  # 64
+        base_depth = max(round(dep_mul * 3), 1)  # 3
+
+        # stem
+        self.stem = Focus(3, base_channels, ksize=3, act=act)
+
+        # dark2
+        self.dark2 = M.Sequential(
+            Conv(base_channels, base_channels * 2, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 2, base_channels * 2,
+                n=base_depth, depthwise=depthwise, act=act
+            ),
+        )
+
+        # dark3
+        self.dark3 = M.Sequential(
+            Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 4, base_channels * 4,
+                n=base_depth * 3, depthwise=depthwise, act=act,
+            ),
+        )
+
+        # dark4
+        self.dark4 = M.Sequential(
+            Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 8, base_channels * 8,
+                n=base_depth * 3, depthwise=depthwise, act=act,
+            ),
+        )
+
+        # dark5
+        self.dark5 = M.Sequential(
+            Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
+            SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
+            CSPLayer(
+                base_channels * 16, base_channels * 16, n=base_depth,
+                shortcut=False, depthwise=depthwise, act=act,
+            ),
+        )
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
diff --git a/multimodal/YOLOX/demo/MegEngine/python/models/network_blocks.py b/multimodal/YOLOX/demo/MegEngine/python/models/network_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0e40d3f2aea5bbd00493311219821a7e5d5e8be
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/models/network_blocks.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import megengine.functional as F
+import megengine.module as M
+
+
+class UpSample(M.Module):
+
+    def __init__(self, scale_factor=2, mode="bilinear"):
+        super().__init__()
+        self.scale_factor = scale_factor
+        self.mode = mode
+
+    def forward(self, x):
+        return F.vision.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
+
+
+class SiLU(M.Module):
+    """export-friendly version of M.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * F.sigmoid(x)
+
+
+def get_activation(name="silu"):
+    if name == "silu":
+        module = SiLU()
+    elif name == "relu":
+        module = M.ReLU()
+    elif name == "lrelu":
+        module = M.LeakyReLU(0.1)
+    else:
+        raise AttributeError("Unsupported act type: {}".format(name))
+    return module
+
+
+class BaseConv(M.Module):
+    """A Conv2d -> Batchnorm -> silu/leaky relu block"""
+
+    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
+        super().__init__()
+        # same padding
+        pad = (ksize - 1) // 2
+        self.conv = M.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=ksize,
+            stride=stride,
+            padding=pad,
+            groups=groups,
+            bias=bias,
+        )
+        self.bn = M.BatchNorm2d(out_channels)
+        self.act = get_activation(act)
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+class DWConv(M.Module):
+    """Depthwise Conv + Conv"""
+    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
+        super().__init__()
+        self.dconv = BaseConv(
+            in_channels, in_channels, ksize=ksize,
+            stride=stride, groups=in_channels, act=act
+        )
+        self.pconv = BaseConv(
+            in_channels, out_channels, ksize=1,
+            stride=1, groups=1, act=act
+        )
+
+    def forward(self, x):
+        x = self.dconv(x)
+        return self.pconv(x)
+
+
+class Bottleneck(M.Module):
+    # Standard bottleneck
+    def __init__(
+        self, in_channels, out_channels, shortcut=True,
+        expansion=0.5, depthwise=False, act="silu"
+    ):
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)
+        Conv = DWConv if depthwise else BaseConv
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
+        self.use_add = shortcut and in_channels == out_channels
+
+    def forward(self, x):
+        y = self.conv2(self.conv1(x))
+        if self.use_add:
+            y = y + x
+        return y
+
+
+class ResLayer(M.Module):
+    "Residual layer with `in_channels` inputs."
+    def __init__(self, in_channels: int):
+        super().__init__()
+        mid_channels = in_channels // 2
+        self.layer1 = BaseConv(in_channels, mid_channels, ksize=1, stride=1, act="lrelu")
+        self.layer2 = BaseConv(mid_channels, in_channels, ksize=3, stride=1, act="lrelu")
+
+    def forward(self, x):
+        out = self.layer2(self.layer1(x))
+        return x + out
+
+
+class SPPBottleneck(M.Module):
+    """Spatial pyramid pooling layer used in YOLOv3-SPP"""
+    def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"):
+        super().__init__()
+        hidden_channels = in_channels // 2
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
+        self.m = [M.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]
+        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
+        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.concat([x] + [m(x) for m in self.m], axis=1)
+        x = self.conv2(x)
+        return x
+
+
+class CSPLayer(M.Module):
+    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""
+
+    def __init__(
+        self, in_channels, out_channels, n=1,
+        shortcut=True, expansion=0.5, depthwise=False, act="silu"
+    ):
+        """
+        Args:
+            in_channels (int): input channels.
+            out_channels (int): output channels.
+            n (int): number of Bottlenecks. Default value: 1.
+        """
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)  # hidden channels
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
+        module_list = [
+            Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act)
+            for _ in range(n)
+        ]
+        self.m = M.Sequential(*module_list)
+
+    def forward(self, x):
+        x_1 = self.conv1(x)
+        x_2 = self.conv2(x)
+        x_1 = self.m(x_1)
+        x = F.concat((x_1, x_2), axis=1)
+        return self.conv3(x)
+
+
+class Focus(M.Module):
+    """Focus width and height information into channel space."""
+
+    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
+        super().__init__()
+        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
+
+    def forward(self, x):
+        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
+        patch_top_left = x[..., ::2, ::2]
+        patch_top_right = x[..., ::2, 1::2]
+        patch_bot_left = x[..., 1::2, ::2]
+        patch_bot_right = x[..., 1::2, 1::2]
+        x = F.concat(
+            (patch_top_left, patch_bot_left, patch_top_right, patch_bot_right,), axis=1,
+        )
+        return self.conv(x)
diff --git a/multimodal/YOLOX/demo/MegEngine/python/models/yolo_fpn.py b/multimodal/YOLOX/demo/MegEngine/python/models/yolo_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..675a7f6e6b8e42ecc8eaf90cfb5b20939b1c3e0d
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/models/yolo_fpn.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import megengine.functional as F
+import megengine.module as M
+
+from .darknet import Darknet
+from .network_blocks import BaseConv, UpSample
+
+
+class YOLOFPN(M.Module):
+    """
+    YOLOFPN module. Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self, depth=53, in_features=["dark3", "dark4", "dark5"],
+    ):
+        super().__init__()
+
+        self.backbone = Darknet(depth)
+        self.in_features = in_features
+
+        # out 1
+        self.out1_cbl = self._make_cbl(512, 256, 1)
+        self.out1 = self._make_embedding([256, 512], 512 + 256)
+
+        # out 2
+        self.out2_cbl = self._make_cbl(256, 128, 1)
+        self.out2 = self._make_embedding([128, 256], 256 + 128)
+
+        # upsample
+        self.upsample = UpSample(scale_factor=2, mode="bilinear")
+
+    def _make_cbl(self, _in, _out, ks):
+        return BaseConv(_in, _out, ks, stride=1, act="lrelu")
+
+    def _make_embedding(self, filters_list, in_filters):
+        m = M.Sequential(
+            *[
+                self._make_cbl(in_filters, filters_list[0], 1),
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+            ]
+        )
+        return m
+
+    def forward(self, inputs):
+        """
+        Args:
+            inputs (Tensor): input image.
+
+        Returns:
+            Tuple[Tensor]: FPN output features..
+        """
+        #  backbone
+        out_features = self.backbone(inputs)
+        x2, x1, x0 = [out_features[f] for f in self.in_features]
+
+        #  yolo branch 1
+        x1_in = self.out1_cbl(x0)
+        x1_in = self.upsample(x1_in)
+        x1_in = F.concat([x1_in, x1], 1)
+        out_dark4 = self.out1(x1_in)
+
+        #  yolo branch 2
+        x2_in = self.out2_cbl(out_dark4)
+        x2_in = self.upsample(x2_in)
+        x2_in = F.concat([x2_in, x2], 1)
+        out_dark3 = self.out2(x2_in)
+
+        outputs = (out_dark3, out_dark4, x0)
+        return outputs
diff --git a/multimodal/YOLOX/demo/MegEngine/python/models/yolo_head.py b/multimodal/YOLOX/demo/MegEngine/python/models/yolo_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bba674d55824bd166389453f7074f9613b49b28
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/models/yolo_head.py
@@ -0,0 +1,192 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import megengine.functional as F
+import megengine.module as M
+
+from .network_blocks import BaseConv, DWConv
+
+
+def meshgrid(x, y):
+    """meshgrid wrapper for megengine"""
+    assert len(x.shape) == 1
+    assert len(y.shape) == 1
+    mesh_shape = (y.shape[0], x.shape[0])
+    mesh_x = F.broadcast_to(x, mesh_shape)
+    mesh_y = F.broadcast_to(y.reshape(-1, 1), mesh_shape)
+    return mesh_x, mesh_y
+
+
+class YOLOXHead(M.Module):
+    def __init__(
+        self, num_classes, width=1.0, strides=[8, 16, 32],
+        in_channels=[256, 512, 1024], act="silu", depthwise=False
+    ):
+        """
+        Args:
+            act (str): activation type of conv. Defalut value: "silu".
+            depthwise (bool): whether apply depthwise conv in conv branch. Defalut value: False.
+        """
+        super().__init__()
+
+        self.n_anchors = 1
+        self.num_classes = num_classes
+        self.decode_in_inference = True  # save for matching
+
+        self.cls_convs = []
+        self.reg_convs = []
+        self.cls_preds = []
+        self.reg_preds = []
+        self.obj_preds = []
+        self.stems = []
+        Conv = DWConv if depthwise else BaseConv
+
+        for i in range(len(in_channels)):
+            self.stems.append(
+                BaseConv(
+                    in_channels=int(in_channels[i] * width),
+                    out_channels=int(256 * width),
+                    ksize=1,
+                    stride=1,
+                    act=act,
+                )
+            )
+            self.cls_convs.append(
+                M.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.reg_convs.append(
+                M.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.cls_preds.append(
+                M.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=self.n_anchors * self.num_classes,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.reg_preds.append(
+                M.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=4,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.obj_preds.append(
+                M.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=self.n_anchors * 1,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+
+        self.use_l1 = False
+        self.strides = strides
+        self.grids = [F.zeros(1)] * len(in_channels)
+
+    def forward(self, xin, labels=None, imgs=None):
+        outputs = []
+        assert not self.training
+
+        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+            zip(self.cls_convs, self.reg_convs, self.strides, xin)
+        ):
+            x = self.stems[k](x)
+            cls_x = x
+            reg_x = x
+
+            cls_feat = cls_conv(cls_x)
+            cls_output = self.cls_preds[k](cls_feat)
+
+            reg_feat = reg_conv(reg_x)
+            reg_output = self.reg_preds[k](reg_feat)
+            obj_output = self.obj_preds[k](reg_feat)
+            output = F.concat([reg_output, F.sigmoid(obj_output), F.sigmoid(cls_output)], 1)
+            outputs.append(output)
+
+        self.hw = [x.shape[-2:] for x in outputs]
+        # [batch, n_anchors_all, 85]
+        outputs = F.concat([F.flatten(x, start_axis=2) for x in outputs], axis=2)
+        outputs = F.transpose(outputs, (0, 2, 1))
+        if self.decode_in_inference:
+            return self.decode_outputs(outputs)
+        else:
+            return outputs
+
+    def get_output_and_grid(self, output, k, stride, dtype):
+        grid = self.grids[k]
+
+        batch_size = output.shape[0]
+        n_ch = 5 + self.num_classes
+        hsize, wsize = output.shape[-2:]
+        if grid.shape[2:4] != output.shape[2:4]:
+            yv, xv = meshgrid([F.arange(hsize), F.arange(wsize)])
+            grid = F.stack((xv, yv), 2).reshape(1, 1, hsize, wsize, 2).type(dtype)
+            self.grids[k] = grid
+
+        output = output.view(batch_size, self.n_anchors, n_ch, hsize, wsize)
+        output = (
+            output.permute(0, 1, 3, 4, 2)
+            .reshape(batch_size, self.n_anchors * hsize * wsize, -1)
+        )
+        grid = grid.view(1, -1, 2)
+        output[..., :2] = (output[..., :2] + grid) * stride
+        output[..., 2:4] = F.exp(output[..., 2:4]) * stride
+        return output, grid
+
+    def decode_outputs(self, outputs):
+        grids = []
+        strides = []
+        for (hsize, wsize), stride in zip(self.hw, self.strides):
+            xv, yv = meshgrid(F.arange(hsize), F.arange(wsize))
+            grid = F.stack((xv, yv), 2).reshape(1, -1, 2)
+            grids.append(grid)
+            shape = grid.shape[:2]
+            strides.append(F.full((*shape, 1), stride))
+
+        grids = F.concat(grids, axis=1)
+        strides = F.concat(strides, axis=1)
+
+        outputs[..., :2] = (outputs[..., :2] + grids) * strides
+        outputs[..., 2:4] = F.exp(outputs[..., 2:4]) * strides
+        return outputs
diff --git a/multimodal/YOLOX/demo/MegEngine/python/models/yolo_pafpn.py b/multimodal/YOLOX/demo/MegEngine/python/models/yolo_pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..86154bfa92e8da44042fb2d152322725d0039040
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/models/yolo_pafpn.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import megengine.module as M
+import megengine.functional as F
+
+from .darknet import CSPDarknet
+from .network_blocks import BaseConv, CSPLayer, DWConv, UpSample
+
+
+class YOLOPAFPN(M.Module):
+    """
+    YOLOv3 model. Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self, depth=1.0, width=1.0, in_features=("dark3", "dark4", "dark5"),
+        in_channels=[256, 512, 1024], depthwise=False, act="silu",
+    ):
+        super().__init__()
+        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
+        self.in_features = in_features
+        self.in_channels = in_channels
+        Conv = DWConv if depthwise else BaseConv
+
+        self.upsample = UpSample(scale_factor=2, mode="bilinear")
+        self.lateral_conv0 = BaseConv(
+            int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
+        )
+        self.C3_p4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )  # cat
+
+        self.reduce_conv1 = BaseConv(
+            int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
+        )
+        self.C3_p3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[0] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+        # bottom-up conv
+        self.bu_conv2 = Conv(
+            int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
+        )
+        self.C3_n3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+        # bottom-up conv
+        self.bu_conv1 = Conv(
+            int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
+        )
+        self.C3_n4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[2] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+    def forward(self, input):
+        """
+        Args:
+            inputs: input images.
+
+        Returns:
+            Tuple[Tensor]: FPN feature.
+        """
+
+        #  backbone
+        out_features = self.backbone(input)
+        features = [out_features[f] for f in self.in_features]
+        [x2, x1, x0] = features
+
+        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
+        f_out0 = self.upsample(fpn_out0)  # 512/16
+        f_out0 = F.concat([f_out0, x1], 1)  # 512->1024/16
+        f_out0 = self.C3_p4(f_out0)  # 1024->512/16
+
+        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
+        f_out1 = self.upsample(fpn_out1)  # 256/8
+        f_out1 = F.concat([f_out1, x2], 1)  # 256->512/8
+        pan_out2 = self.C3_p3(f_out1)  # 512->256/8
+
+        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
+        p_out1 = F.concat([p_out1, fpn_out1], 1)  # 256->512/16
+        pan_out1 = self.C3_n3(p_out1)  # 512->512/16
+
+        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
+        p_out0 = F.concat([p_out0, fpn_out0], 1)  # 512->1024/32
+        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32
+
+        outputs = (pan_out2, pan_out1, pan_out0)
+        return outputs
diff --git a/multimodal/YOLOX/demo/MegEngine/python/models/yolox.py b/multimodal/YOLOX/demo/MegEngine/python/models/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..657049fd36340381224938e224ffe729f39c9d90
--- /dev/null
+++ b/multimodal/YOLOX/demo/MegEngine/python/models/yolox.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import megengine.module as M
+
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+
+
+class YOLOX(M.Module):
+    """
+    YOLOX model module. The module list is defined by create_yolov3_modules function.
+    The network returns loss values from three YOLO layers during training
+    and detection results during test.
+    """
+
+    def __init__(self, backbone=None, head=None):
+        super().__init__()
+        if backbone is None:
+            backbone = YOLOPAFPN()
+        if head is None:
+            head = YOLOXHead(80)
+
+        self.backbone = backbone
+        self.head = head
+
+    def forward(self, x):
+        # fpn output content features of [dark3, dark4, dark5]
+        fpn_outs = self.backbone(x)
+        assert not self.training
+        outputs = self.head(fpn_outs)
+
+        return outputs
diff --git a/multimodal/YOLOX/demo/ONNXRuntime/README.md b/multimodal/YOLOX/demo/ONNXRuntime/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6af0944a6b3a984045daf2d4215f96290ed5e9af
--- /dev/null
+++ b/multimodal/YOLOX/demo/ONNXRuntime/README.md
@@ -0,0 +1,78 @@
+## YOLOX-ONNXRuntime in Python
+
+This doc introduces how to convert your pytorch model into onnx, and how to run an onnxruntime demo to verify your convertion.
+
+### Step1: Install onnxruntime
+
+run the following command to install onnxruntime:
+```shell
+pip install onnxruntime
+```
+
+### Step2: Get ONNX models
+
+Users might download our pre-generated ONNX models or convert their own models to ONNX.
+
+#### Download ONNX models.
+
+| Model | Parameters | GFLOPs | Test Size | mAP | Weights |
+|:------| :----: | :----: | :---: | :---: | :---: |
+|  YOLOX-Nano |  0.91M  | 1.08 | 416x416 | 25.8 |[github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_nano.onnx) |
+|  YOLOX-Tiny | 5.06M     | 6.45 | 416x416 |32.8 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_tiny.onnx) |
+|  YOLOX-S | 9.0M | 26.8 | 640x640 |40.5 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.onnx) |
+|  YOLOX-M | 25.3M | 73.8 | 640x640 |47.2 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_m.onnx) |
+|  YOLOX-L | 54.2M | 155.6 | 640x640 |50.1 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_l.onnx) |
+|  YOLOX-Darknet53| 63.72M | 185.3 | 640x640 |48.0 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_darknet.onnx) |
+|  YOLOX-X | 99.1M | 281.9 | 640x640 |51.5 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x.onnx) |
+
+#### Convert Your Model to ONNX
+
+First, you should move to <YOLOX_HOME> by:
+```shell
+cd <YOLOX_HOME>
+```
+Then, you can:
+
+1. Convert a standard YOLOX model by -n:
+```shell
+python3 tools/export_onnx.py --output-name yolox_s.onnx -n yolox-s -c yolox_s.pth
+```
+Notes:
+* -n: specify a model name. The model name must be one of the [yolox-s,m,l,x and yolox-nano, yolox-tiny, yolov3]
+* -c: the model you have trained
+* -o: opset version, default 11. **However, if you will further convert your onnx model to [OpenVINO](https://github.com/Megvii-BaseDetection/YOLOX/demo/OpenVINO/), please specify the opset version to 10.**
+* --no-onnxsim: disable onnxsim
+* To customize an input shape for onnx model,  modify the following code in tools/export.py:
+
+    ```python
+    dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])
+    ```
+
+1. Convert a standard YOLOX model by -f. When using -f, the above command is equivalent to:
+
+```shell
+python3 tools/export_onnx.py --output-name yolox_s.onnx -f exps/default/yolox_s.py -c yolox_s.pth
+```
+
+3. To convert your customized model, please use -f:
+
+```shell
+python3 tools/export_onnx.py --output-name your_yolox.onnx -f exps/your_dir/your_yolox.py -c your_yolox.pth
+```
+
+### Step3: ONNXRuntime Demo
+
+Step1.
+```shell
+cd <YOLOX_HOME>/demo/ONNXRuntime
+```
+
+Step2. 
+```shell
+python3 onnx_inference.py -m <ONNX_MODEL_PATH> -i <IMAGE_PATH> -o <OUTPUT_DIR> -s 0.3 --input_shape 640,640
+```
+Notes:
+* -m: your converted onnx model
+* -i: input_image
+* -s: score threshold for visualization.
+* --input_shape: should be consistent with the shape you used for onnx convertion.
diff --git a/multimodal/YOLOX/demo/ONNXRuntime/onnx_inference.py b/multimodal/YOLOX/demo/ONNXRuntime/onnx_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..07654dc9b981d5640274254cc945bad0bbaa1cdf
--- /dev/null
+++ b/multimodal/YOLOX/demo/ONNXRuntime/onnx_inference.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+
+import cv2
+import numpy as np
+
+import onnxruntime
+
+from yolox.data.data_augment import preproc as preprocess
+from yolox.data.datasets import COCO_CLASSES
+from yolox.utils import mkdir, multiclass_nms, demo_postprocess, vis
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("onnxruntime inference sample")
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        default="yolox.onnx",
+        help="Input your onnx model.",
+    )
+    parser.add_argument(
+        "-i",
+        "--image_path",
+        type=str,
+        default='test_image.png',
+        help="Path to your input image.",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_dir",
+        type=str,
+        default='demo_output',
+        help="Path to your output directory.",
+    )
+    parser.add_argument(
+        "-s",
+        "--score_thr",
+        type=float,
+        default=0.3,
+        help="Score threshould to filter the result.",
+    )
+    parser.add_argument(
+        "--input_shape",
+        type=str,
+        default="640,640",
+        help="Specify an input shape for inference.",
+    )
+    return parser
+
+
+if __name__ == '__main__':
+    args = make_parser().parse_args()
+
+    input_shape = tuple(map(int, args.input_shape.split(',')))
+    origin_img = cv2.imread(args.image_path)
+    img, ratio = preprocess(origin_img, input_shape)
+
+    session = onnxruntime.InferenceSession(args.model)
+
+    ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
+    output = session.run(None, ort_inputs)
+    predictions = demo_postprocess(output[0], input_shape)[0]
+
+    boxes = predictions[:, :4]
+    scores = predictions[:, 4:5] * predictions[:, 5:]
+
+    boxes_xyxy = np.ones_like(boxes)
+    boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
+    boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
+    boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
+    boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
+    boxes_xyxy /= ratio
+    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+    if dets is not None:
+        final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
+        origin_img = vis(origin_img, final_boxes, final_scores, final_cls_inds,
+                         conf=args.score_thr, class_names=COCO_CLASSES)
+
+    mkdir(args.output_dir)
+    output_path = os.path.join(args.output_dir, os.path.basename(args.image_path))
+    cv2.imwrite(output_path, origin_img)
diff --git a/multimodal/YOLOX/demo/OpenVINO/README.md b/multimodal/YOLOX/demo/OpenVINO/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..559708f13f2f21bbb16ae331f50a625014a7b28b
--- /dev/null
+++ b/multimodal/YOLOX/demo/OpenVINO/README.md
@@ -0,0 +1,4 @@
+## YOLOX for OpenVINO
+
+* [C++ Demo](./cpp)
+* [Python Demo](./python)
\ No newline at end of file
diff --git a/multimodal/YOLOX/demo/OpenVINO/cpp/CMakeLists.txt b/multimodal/YOLOX/demo/OpenVINO/cpp/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..614739bda584016b5b46cfc356ba94d23be43464
--- /dev/null
+++ b/multimodal/YOLOX/demo/OpenVINO/cpp/CMakeLists.txt
@@ -0,0 +1,23 @@
+cmake_minimum_required(VERSION 3.4.1)
+set(CMAKE_CXX_STANDARD 14)
+
+project(yolox_openvino_demo)
+
+find_package(OpenCV REQUIRED)
+find_package(InferenceEngine REQUIRED)
+find_package(ngraph REQUIRED)
+
+include_directories(
+    ${OpenCV_INCLUDE_DIRS}
+    ${CMAKE_CURRENT_SOURCE_DIR}
+    ${CMAKE_CURRENT_BINARY_DIR}
+)
+
+add_executable(yolox_openvino yolox_openvino.cpp)
+
+target_link_libraries(
+     yolox_openvino
+    ${InferenceEngine_LIBRARIES}
+    ${NGRAPH_LIBRARIES}
+    ${OpenCV_LIBS} 
+)
\ No newline at end of file
diff --git a/multimodal/YOLOX/demo/OpenVINO/cpp/README.md b/multimodal/YOLOX/demo/OpenVINO/cpp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c877d94c2834da117c49df41aa936614c175c6df
--- /dev/null
+++ b/multimodal/YOLOX/demo/OpenVINO/cpp/README.md
@@ -0,0 +1,97 @@
+# YOLOX-OpenVINO in C++
+
+This tutorial includes a C++ demo for OpenVINO, as well as some converted models.
+
+### Download OpenVINO models.
+
+| Model | Parameters | GFLOPs | Test Size | mAP | Weights |
+|:------| :----: | :----: | :---: | :---: | :---: |
+|  [YOLOX-Nano](../../../exps/default/nano.py) |  0.91M  | 1.08 | 416x416 | 25.8 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_nano_openvino.tar.gz) |
+|  [YOLOX-Tiny](../../../exps/default/yolox_tiny.py) | 5.06M     | 6.45 | 416x416 |32.8 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_tiny_openvino.tar.gz) |
+|  [YOLOX-S](../../../exps/default/yolox_s.py) | 9.0M | 26.8 | 640x640 |40.5 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s_openvino.tar.gz) |
+|  [YOLOX-M](../../../exps/default/yolox_m.py) | 25.3M | 73.8 | 640x640 |47.2 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_m_openvino.tar.gz) |
+|  [YOLOX-L](../../../exps/default/yolox_l.py) | 54.2M | 155.6 | 640x640 |50.1 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_l_openvino.tar.gz) |
+|  [YOLOX-Darknet53](../../../exps/default/yolov3.py) | 63.72M | 185.3 | 640x640 |48.0 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_dark_openvino.tar.gz) | 
+|  [YOLOX-X](../../../exps/default/yolox_x.py) | 99.1M | 281.9 | 640x640 |51.5 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x_openvino.tar.gz) |
+
+## Install OpenVINO Toolkit
+
+Please visit [Openvino Homepage](https://docs.openvinotoolkit.org/latest/get_started_guides.html) for more details.
+
+## Set up the Environment
+
+### For Linux
+
+**Option1. Set up the environment tempororally. You need to run this command everytime you start a new shell window.**
+
+```shell
+source /opt/intel/openvino_2021/bin/setupvars.sh
+```
+
+**Option2. Set up the environment permenantly.**
+
+*Step1.* For Linux:
+```shell
+vim ~/.bashrc 
+```
+
+*Step2.* Add the following line into your file:
+
+```shell
+source /opt/intel/openvino_2021/bin/setupvars.sh
+```
+
+*Step3.* Save and exit the file, then run:
+
+```shell
+source ~/.bashrc
+```
+
+
+## Convert model
+
+1. Export ONNX model
+   
+   Please refer to the [ONNX tutorial](../../ONNXRuntime). **Note that you should set --opset to 10, otherwise your next step will fail.**
+
+2. Convert ONNX to OpenVINO 
+
+   ``` shell
+   cd <INSTSLL_DIR>/openvino_2021/deployment_tools/model_optimizer
+   ```
+
+   Install requirements for convert tool
+
+   ```shell
+   sudo ./install_prerequisites/install_prerequisites_onnx.sh
+   ```
+
+   Then convert model.
+   ```shell
+   python3 mo.py --input_model <ONNX_MODEL> --input_shape <INPUT_SHAPE> [--data_type FP16]
+   ```
+   For example:
+   ```shell
+   python3 mo.py --input_model yolox_tiny.onnx --input_shape [1,3,416,416] --data_type FP16
+   ```  
+
+   Make sure the input shape is consistent with [those](yolox_openvino.cpp#L24-L25) in cpp file. 
+
+## Build 
+
+### Linux
+```shell
+source /opt/intel/openvino_2021/bin/setupvars.sh
+mkdir build
+cd build
+cmake ..
+make
+```
+
+## Demo
+
+### c++
+
+```shell
+./yolox_openvino <XML_MODEL_PATH> <IMAGE_PATH> <DEVICE>
+```
diff --git a/multimodal/YOLOX/demo/OpenVINO/cpp/yolox_openvino.cpp b/multimodal/YOLOX/demo/OpenVINO/cpp/yolox_openvino.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f42344141cd760737e9d2b617d776480d4379a7d
--- /dev/null
+++ b/multimodal/YOLOX/demo/OpenVINO/cpp/yolox_openvino.cpp
@@ -0,0 +1,529 @@
+// Copyright (C) 2018-2021 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <iterator>
+#include <memory>
+#include <string>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <iostream>
+#include <inference_engine.hpp>
+
+using namespace InferenceEngine;
+
+/**
+ * @brief Define names based depends on Unicode path support
+ */
+#define tcout                  std::cout
+#define file_name_t            std::string
+#define imread_t               cv::imread
+#define NMS_THRESH 0.45
+#define BBOX_CONF_THRESH 0.3
+
+static const int INPUT_W = 416;
+static const int INPUT_H = 416;
+static const int NUM_CLASSES = 80; // COCO has 80 classes. Modify this value on your own dataset.
+
+cv::Mat static_resize(cv::Mat& img) {
+    float r = std::min(INPUT_W / (img.cols*1.0), INPUT_H / (img.rows*1.0));
+    // r = std::min(r, 1.0f);
+    int unpad_w = r * img.cols;
+    int unpad_h = r * img.rows;
+    cv::Mat re(unpad_h, unpad_w, CV_8UC3);
+    cv::resize(img, re, re.size());
+    //cv::Mat out(INPUT_W, INPUT_H, CV_8UC3, cv::Scalar(114, 114, 114));
+    cv::Mat out(INPUT_H, INPUT_W, CV_8UC3, cv::Scalar(114, 114, 114));
+    re.copyTo(out(cv::Rect(0, 0, re.cols, re.rows)));
+    return out;
+}
+
+void blobFromImage(cv::Mat& img, Blob::Ptr& blob){
+    int channels = 3;
+    int img_h = img.rows;
+    int img_w = img.cols;
+    InferenceEngine::MemoryBlob::Ptr mblob = InferenceEngine::as<InferenceEngine::MemoryBlob>(blob);
+    if (!mblob) 
+    {
+        THROW_IE_EXCEPTION << "We expect blob to be inherited from MemoryBlob in matU8ToBlob, "
+            << "but by fact we were not able to cast inputBlob to MemoryBlob";
+    }
+    // locked memory holder should be alive all time while access to its buffer happens
+    auto mblobHolder = mblob->wmap();
+
+    float *blob_data = mblobHolder.as<float *>();
+
+    for (size_t c = 0; c < channels; c++) 
+    {
+        for (size_t  h = 0; h < img_h; h++) 
+        {
+            for (size_t w = 0; w < img_w; w++) 
+            {
+                blob_data[c * img_w * img_h + h * img_w + w] =
+                    (float)img.at<cv::Vec3b>(h, w)[c];
+            }
+        }
+    }
+}
+
+
+struct Object
+{
+    cv::Rect_<float> rect;
+    int label;
+    float prob;
+};
+
+struct GridAndStride
+{
+    int grid0;
+    int grid1;
+    int stride;
+};
+
+static void generate_grids_and_stride(const int target_w, const int target_h, std::vector<int>& strides, std::vector<GridAndStride>& grid_strides)
+{
+    for (auto stride : strides)
+    {
+        int num_grid_w = target_w / stride;
+        int num_grid_h = target_h / stride;
+        for (int g1 = 0; g1 < num_grid_h; g1++)
+        {
+            for (int g0 = 0; g0 < num_grid_w; g0++)
+            {
+                grid_strides.push_back((GridAndStride){g0, g1, stride});
+            }
+        }
+    }
+}
+
+
+static void generate_yolox_proposals(std::vector<GridAndStride> grid_strides, const float* feat_ptr, float prob_threshold, std::vector<Object>& objects)
+{
+
+    const int num_anchors = grid_strides.size();
+
+    for (int anchor_idx = 0; anchor_idx < num_anchors; anchor_idx++)
+    {
+        const int grid0 = grid_strides[anchor_idx].grid0;
+        const int grid1 = grid_strides[anchor_idx].grid1;
+        const int stride = grid_strides[anchor_idx].stride;
+
+	const int basic_pos = anchor_idx * (NUM_CLASSES + 5);
+
+        // yolox/models/yolo_head.py decode logic
+        //  outputs[..., :2] = (outputs[..., :2] + grids) * strides
+        //  outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
+        float x_center = (feat_ptr[basic_pos + 0] + grid0) * stride;
+        float y_center = (feat_ptr[basic_pos + 1] + grid1) * stride;
+        float w = exp(feat_ptr[basic_pos + 2]) * stride;
+        float h = exp(feat_ptr[basic_pos + 3]) * stride;
+        float x0 = x_center - w * 0.5f;
+        float y0 = y_center - h * 0.5f;
+
+        float box_objectness = feat_ptr[basic_pos + 4];
+        for (int class_idx = 0; class_idx < NUM_CLASSES; class_idx++)
+        {
+            float box_cls_score = feat_ptr[basic_pos + 5 + class_idx];
+            float box_prob = box_objectness * box_cls_score;
+            if (box_prob > prob_threshold)
+            {
+                Object obj;
+                obj.rect.x = x0;
+                obj.rect.y = y0;
+                obj.rect.width = w;
+                obj.rect.height = h;
+                obj.label = class_idx;
+                obj.prob = box_prob;
+
+                objects.push_back(obj);
+            }
+
+        } // class loop
+
+    } // point anchor loop
+}
+
+static inline float intersection_area(const Object& a, const Object& b)
+{
+    cv::Rect_<float> inter = a.rect & b.rect;
+    return inter.area();
+}
+
+static void qsort_descent_inplace(std::vector<Object>& faceobjects, int left, int right)
+{
+    int i = left;
+    int j = right;
+    float p = faceobjects[(left + right) / 2].prob;
+
+    while (i <= j)
+    {
+        while (faceobjects[i].prob > p)
+            i++;
+
+        while (faceobjects[j].prob < p)
+            j--;
+
+        if (i <= j)
+        {
+            // swap
+            std::swap(faceobjects[i], faceobjects[j]);
+
+            i++;
+            j--;
+        }
+    }
+
+    #pragma omp parallel sections
+    {
+        #pragma omp section
+        {
+            if (left < j) qsort_descent_inplace(faceobjects, left, j);
+        }
+        #pragma omp section
+        {
+            if (i < right) qsort_descent_inplace(faceobjects, i, right);
+        }
+    }
+}
+
+
+static void qsort_descent_inplace(std::vector<Object>& objects)
+{
+    if (objects.empty())
+        return;
+
+    qsort_descent_inplace(objects, 0, objects.size() - 1);
+}
+
+static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold)
+{
+    picked.clear();
+
+    const int n = faceobjects.size();
+
+    std::vector<float> areas(n);
+    for (int i = 0; i < n; i++)
+    {
+        areas[i] = faceobjects[i].rect.area();
+    }
+
+    for (int i = 0; i < n; i++)
+    {
+        const Object& a = faceobjects[i];
+
+        int keep = 1;
+        for (int j = 0; j < (int)picked.size(); j++)
+        {
+            const Object& b = faceobjects[picked[j]];
+
+            // intersection over union
+            float inter_area = intersection_area(a, b);
+            float union_area = areas[i] + areas[picked[j]] - inter_area;
+            // float IoU = inter_area / union_area
+            if (inter_area / union_area > nms_threshold)
+                keep = 0;
+        }
+
+        if (keep)
+            picked.push_back(i);
+    }
+}
+
+
+static void decode_outputs(const float* prob, std::vector<Object>& objects, float scale, const int img_w, const int img_h) {
+        std::vector<Object> proposals;
+        std::vector<int> strides = {8, 16, 32};
+        std::vector<GridAndStride> grid_strides;
+
+        generate_grids_and_stride(INPUT_W, INPUT_H, strides, grid_strides);
+        generate_yolox_proposals(grid_strides, prob,  BBOX_CONF_THRESH, proposals);
+        qsort_descent_inplace(proposals);
+
+        std::vector<int> picked;
+        nms_sorted_bboxes(proposals, picked, NMS_THRESH);
+        int count = picked.size();
+        objects.resize(count);
+
+        for (int i = 0; i < count; i++)
+        {
+            objects[i] = proposals[picked[i]];
+
+            // adjust offset to original unpadded
+            float x0 = (objects[i].rect.x) / scale;
+            float y0 = (objects[i].rect.y) / scale;
+            float x1 = (objects[i].rect.x + objects[i].rect.width) / scale;
+            float y1 = (objects[i].rect.y + objects[i].rect.height) / scale;
+
+            // clip
+            x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
+            y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
+            x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
+            y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);
+
+            objects[i].rect.x = x0;
+            objects[i].rect.y = y0;
+            objects[i].rect.width = x1 - x0;
+            objects[i].rect.height = y1 - y0;
+        }
+}
+
+const float color_list[80][3] =
+{
+    {0.000, 0.447, 0.741},
+    {0.850, 0.325, 0.098},
+    {0.929, 0.694, 0.125},
+    {0.494, 0.184, 0.556},
+    {0.466, 0.674, 0.188},
+    {0.301, 0.745, 0.933},
+    {0.635, 0.078, 0.184},
+    {0.300, 0.300, 0.300},
+    {0.600, 0.600, 0.600},
+    {1.000, 0.000, 0.000},
+    {1.000, 0.500, 0.000},
+    {0.749, 0.749, 0.000},
+    {0.000, 1.000, 0.000},
+    {0.000, 0.000, 1.000},
+    {0.667, 0.000, 1.000},
+    {0.333, 0.333, 0.000},
+    {0.333, 0.667, 0.000},
+    {0.333, 1.000, 0.000},
+    {0.667, 0.333, 0.000},
+    {0.667, 0.667, 0.000},
+    {0.667, 1.000, 0.000},
+    {1.000, 0.333, 0.000},
+    {1.000, 0.667, 0.000},
+    {1.000, 1.000, 0.000},
+    {0.000, 0.333, 0.500},
+    {0.000, 0.667, 0.500},
+    {0.000, 1.000, 0.500},
+    {0.333, 0.000, 0.500},
+    {0.333, 0.333, 0.500},
+    {0.333, 0.667, 0.500},
+    {0.333, 1.000, 0.500},
+    {0.667, 0.000, 0.500},
+    {0.667, 0.333, 0.500},
+    {0.667, 0.667, 0.500},
+    {0.667, 1.000, 0.500},
+    {1.000, 0.000, 0.500},
+    {1.000, 0.333, 0.500},
+    {1.000, 0.667, 0.500},
+    {1.000, 1.000, 0.500},
+    {0.000, 0.333, 1.000},
+    {0.000, 0.667, 1.000},
+    {0.000, 1.000, 1.000},
+    {0.333, 0.000, 1.000},
+    {0.333, 0.333, 1.000},
+    {0.333, 0.667, 1.000},
+    {0.333, 1.000, 1.000},
+    {0.667, 0.000, 1.000},
+    {0.667, 0.333, 1.000},
+    {0.667, 0.667, 1.000},
+    {0.667, 1.000, 1.000},
+    {1.000, 0.000, 1.000},
+    {1.000, 0.333, 1.000},
+    {1.000, 0.667, 1.000},
+    {0.333, 0.000, 0.000},
+    {0.500, 0.000, 0.000},
+    {0.667, 0.000, 0.000},
+    {0.833, 0.000, 0.000},
+    {1.000, 0.000, 0.000},
+    {0.000, 0.167, 0.000},
+    {0.000, 0.333, 0.000},
+    {0.000, 0.500, 0.000},
+    {0.000, 0.667, 0.000},
+    {0.000, 0.833, 0.000},
+    {0.000, 1.000, 0.000},
+    {0.000, 0.000, 0.167},
+    {0.000, 0.000, 0.333},
+    {0.000, 0.000, 0.500},
+    {0.000, 0.000, 0.667},
+    {0.000, 0.000, 0.833},
+    {0.000, 0.000, 1.000},
+    {0.000, 0.000, 0.000},
+    {0.143, 0.143, 0.143},
+    {0.286, 0.286, 0.286},
+    {0.429, 0.429, 0.429},
+    {0.571, 0.571, 0.571},
+    {0.714, 0.714, 0.714},
+    {0.857, 0.857, 0.857},
+    {0.000, 0.447, 0.741},
+    {0.314, 0.717, 0.741},
+    {0.50, 0.5, 0}
+};
+
+static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
+{
+    static const char* class_names[] = {
+        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
+        "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
+        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
+        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
+        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
+        "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
+        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
+        "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
+        "hair drier", "toothbrush"
+    };
+
+    cv::Mat image = bgr.clone();
+
+    for (size_t i = 0; i < objects.size(); i++)
+    {
+        const Object& obj = objects[i];
+
+        fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
+                obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);
+
+        cv::Scalar color = cv::Scalar(color_list[obj.label][0], color_list[obj.label][1], color_list[obj.label][2]);
+        float c_mean = cv::mean(color)[0];
+        cv::Scalar txt_color;
+        if (c_mean > 0.5){
+            txt_color = cv::Scalar(0, 0, 0);
+        }else{
+            txt_color = cv::Scalar(255, 255, 255);
+        }
+
+        cv::rectangle(image, obj.rect, color * 255, 2);
+
+        char text[256];
+        sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
+
+        int baseLine = 0;
+        cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
+
+        cv::Scalar txt_bk_color = color * 0.7 * 255;
+
+        int x = obj.rect.x;
+        int y = obj.rect.y + 1;
+        //int y = obj.rect.y - label_size.height - baseLine;
+        if (y > image.rows)
+            y = image.rows;
+        //if (x + label_size.width > image.cols)
+            //x = image.cols - label_size.width;
+
+        cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
+                      txt_bk_color, -1);
+
+        cv::putText(image, text, cv::Point(x, y + label_size.height),
+                    cv::FONT_HERSHEY_SIMPLEX, 0.4, txt_color, 1);
+    }
+
+    cv::imwrite("_demo.jpg" , image);
+    fprintf(stderr, "save vis file\n");
+    /* cv::imshow("image", image); */
+    /* cv::waitKey(0); */
+}
+
+
+int main(int argc, char* argv[]) {
+    try {
+        // ------------------------------ Parsing and validation of input arguments
+        // ---------------------------------
+        if (argc != 4) {
+            tcout << "Usage : " << argv[0] << " <path_to_model> <path_to_image> <device_name>" << std::endl;
+            return EXIT_FAILURE;
+        }
+
+        const file_name_t input_model {argv[1]};
+        const file_name_t input_image_path {argv[2]};
+        const std::string device_name {argv[3]};
+        // -----------------------------------------------------------------------------------------------------
+
+        // --------------------------- Step 1. Initialize inference engine core
+        // -------------------------------------
+        Core ie;
+        // -----------------------------------------------------------------------------------------------------
+
+        // Step 2. Read a model in OpenVINO Intermediate Representation (.xml and
+        // .bin files) or ONNX (.onnx file) format
+        CNNNetwork network = ie.ReadNetwork(input_model);
+        if (network.getOutputsInfo().size() != 1)
+            throw std::logic_error("Sample supports topologies with 1 output only");
+        if (network.getInputsInfo().size() != 1)
+            throw std::logic_error("Sample supports topologies with 1 input only");
+        // -----------------------------------------------------------------------------------------------------
+
+        // --------------------------- Step 3. Configure input & output
+        // ---------------------------------------------
+        // --------------------------- Prepare input blobs
+        // -----------------------------------------------------
+        InputInfo::Ptr input_info = network.getInputsInfo().begin()->second;
+        std::string input_name = network.getInputsInfo().begin()->first;
+
+        /* Mark input as resizable by setting of a resize algorithm.
+         * In this case we will be able to set an input blob of any shape to an
+         * infer request. Resize and layout conversions are executed automatically
+         * during inference */
+        //input_info->getPreProcess().setResizeAlgorithm(RESIZE_BILINEAR);
+        //input_info->setLayout(Layout::NHWC);
+        //input_info->setPrecision(Precision::FP32);
+
+        // --------------------------- Prepare output blobs
+        // ----------------------------------------------------
+        if (network.getOutputsInfo().empty()) {
+            std::cerr << "Network outputs info is empty" << std::endl;
+            return EXIT_FAILURE;
+        }
+        DataPtr output_info = network.getOutputsInfo().begin()->second;
+        std::string output_name = network.getOutputsInfo().begin()->first;
+
+        output_info->setPrecision(Precision::FP32);
+        // -----------------------------------------------------------------------------------------------------
+
+        // --------------------------- Step 4. Loading a model to the device
+        // ------------------------------------------
+        ExecutableNetwork executable_network = ie.LoadNetwork(network, device_name);
+        // -----------------------------------------------------------------------------------------------------
+
+        // --------------------------- Step 5. Create an infer request
+        // -------------------------------------------------
+        InferRequest infer_request = executable_network.CreateInferRequest();
+        // -----------------------------------------------------------------------------------------------------
+
+        // --------------------------- Step 6. Prepare input
+        // --------------------------------------------------------
+        /* Read input image to a blob and set it to an infer request without resize
+         * and layout conversions. */
+        cv::Mat image = imread_t(input_image_path);
+	    cv::Mat pr_img = static_resize(image);
+        Blob::Ptr imgBlob = infer_request.GetBlob(input_name);     // just wrap Mat data by Blob::Ptr
+	    blobFromImage(pr_img, imgBlob);
+
+        // infer_request.SetBlob(input_name, imgBlob);  // infer_request accepts input blob of any size
+        // -----------------------------------------------------------------------------------------------------
+
+        // --------------------------- Step 7. Do inference
+        // --------------------------------------------------------
+        /* Running the request synchronously */
+        infer_request.Infer();
+        // -----------------------------------------------------------------------------------------------------
+
+        // --------------------------- Step 8. Process output
+        // ------------------------------------------------------
+        const Blob::Ptr output_blob = infer_request.GetBlob(output_name);
+        MemoryBlob::CPtr moutput = as<MemoryBlob>(output_blob);
+        if (!moutput) {
+            throw std::logic_error("We expect output to be inherited from MemoryBlob, "
+                                   "but by fact we were not able to cast output to MemoryBlob");
+        }
+        // locked memory holder should be alive all time while access to its buffer
+        // happens
+        auto moutputHolder = moutput->rmap();
+        const float* net_pred = moutputHolder.as<const PrecisionTrait<Precision::FP32>::value_type*>();
+        
+	    int img_w = image.cols;
+        int img_h = image.rows;
+	    float scale = std::min(INPUT_W / (image.cols*1.0), INPUT_H / (image.rows*1.0));
+        std::vector<Object> objects;
+
+        decode_outputs(net_pred, objects, scale, img_w, img_h);
+        draw_objects(image, objects);
+
+            // -----------------------------------------------------------------------------------------------------
+        } catch (const std::exception& ex) {
+            std::cerr << ex.what() << std::endl;
+            return EXIT_FAILURE;
+    }
+    return EXIT_SUCCESS;
+}
diff --git a/multimodal/YOLOX/demo/OpenVINO/python/README.md b/multimodal/YOLOX/demo/OpenVINO/python/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bbaf5aca44e86523c428735745848d2839351552
--- /dev/null
+++ b/multimodal/YOLOX/demo/OpenVINO/python/README.md
@@ -0,0 +1,89 @@
+# YOLOX-OpenVINO in Python
+
+This tutorial includes a Python demo for OpenVINO, as well as some converted models.
+
+### Download OpenVINO models.
+
+| Model | Parameters | GFLOPs | Test Size | mAP | Weights |
+|:------| :----: | :----: | :---: | :---: | :---: |
+|  [YOLOX-Nano](../../../exps/default/nano.py) |  0.91M  | 1.08 | 416x416 | 25.8 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_nano_openvino.tar.gz) |
+|  [YOLOX-Tiny](../../../exps/default/yolox_tiny.py) | 5.06M     | 6.45 | 416x416 |32.8 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_tiny_openvino.tar.gz) |
+|  [YOLOX-S](../../../exps/default/yolox_s.py) | 9.0M | 26.8 | 640x640 |40.5 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s_openvino.tar.gz) |
+|  [YOLOX-M](../../../exps/default/yolox_m.py) | 25.3M | 73.8 | 640x640 |47.2 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_m_openvino.tar.gz) |
+|  [YOLOX-L](../../../exps/default/yolox_l.py) | 54.2M | 155.6 | 640x640 |50.1 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_l_openvino.tar.gz) |
+|  [YOLOX-Darknet53](../../../exps/default/yolov3.py) | 63.72M | 185.3 | 640x640 |48.0 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_dark_openvino.tar.gz) | 
+|  [YOLOX-X](../../../exps/default/yolox_x.py) | 99.1M | 281.9 | 640x640 |51.5 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x_openvino.tar.gz) |
+
+## Install OpenVINO Toolkit
+
+Please visit [Openvino Homepage](https://docs.openvinotoolkit.org/latest/get_started_guides.html) for more details.
+
+## Set up the Environment
+
+### For Linux
+
+**Option1. Set up the environment tempororally. You need to run this command everytime you start a new shell window.**
+
+```shell
+source /opt/intel/openvino_2021/bin/setupvars.sh
+```
+
+**Option2. Set up the environment permenantly.**
+
+*Step1.* For Linux:
+```shell
+vim ~/.bashrc
+```
+
+*Step2.* Add the following line into your file:
+
+```shell
+source /opt/intel/openvino_2021/bin/setupvars.sh
+```
+
+*Step3.* Save and exit the file, then run:
+
+```shell
+source ~/.bashrc
+```
+
+
+## Convert model
+
+1. Export ONNX model
+
+   Please refer to the [ONNX tutorial](https://github.com/Megvii-BaseDetection/YOLOX/demo/ONNXRuntime). **Note that you should set --opset to 10, otherwise your next step will fail.**
+
+2. Convert ONNX to OpenVINO
+
+   ``` shell
+   cd <INSTSLL_DIR>/openvino_2021/deployment_tools/model_optimizer
+   ```
+
+   Install requirements for convert tool
+
+   ```shell
+   sudo ./install_prerequisites/install_prerequisites_onnx.sh
+   ```
+
+   Then convert model.
+   ```shell
+   python3 mo.py --input_model <ONNX_MODEL> --input_shape <INPUT_SHAPE> [--data_type FP16]
+   ```
+   For example:
+   ```shell
+   python3 mo.py --input_model yolox.onnx --input_shape [1,3,640,640] --data_type FP16 --output_dir converted_output
+   ```
+
+## Demo
+
+### python
+
+```shell
+python openvino_inference.py -m <XML_MODEL_PATH> -i <IMAGE_PATH> 
+```
+or
+```shell
+python openvino_inference.py -m <XML_MODEL_PATH> -i <IMAGE_PATH> -o <OUTPUT_DIR> -s <SCORE_THR> -d <DEVICE>
+```
+
diff --git a/multimodal/YOLOX/demo/OpenVINO/python/openvino_inference.py b/multimodal/YOLOX/demo/OpenVINO/python/openvino_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..00952880043c8b24c738324ee3f527aca7774f75
--- /dev/null
+++ b/multimodal/YOLOX/demo/OpenVINO/python/openvino_inference.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import logging as log
+import os
+import sys
+
+import cv2
+import numpy as np
+
+from openvino.inference_engine import IECore
+
+from yolox.data.data_augment import preproc as preprocess
+from yolox.data.datasets import COCO_CLASSES
+from yolox.utils import mkdir, multiclass_nms, demo_postprocess, vis
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse and return command line arguments"""
+    parser = argparse.ArgumentParser(add_help=False)
+    args = parser.add_argument_group('Options')
+    args.add_argument(
+        '-h',
+        '--help',
+        action='help',
+        help='Show this help message and exit.')
+    args.add_argument(
+        '-m',
+        '--model',
+        required=True,
+        type=str,
+        help='Required. Path to an .xml or .onnx file with a trained model.')
+    args.add_argument(
+        '-i',
+        '--input',
+        required=True,
+        type=str,
+        help='Required. Path to an image file.')
+    args.add_argument(
+        '-o',
+        '--output_dir',
+        type=str,
+        default='demo_output',
+        help='Path to your output dir.')
+    args.add_argument(
+        '-s',
+        '--score_thr',
+        type=float,
+        default=0.3,
+        help="Score threshould to visualize the result.")
+    args.add_argument(
+        '-d',
+        '--device',
+        default='CPU',
+        type=str,
+        help='Optional. Specify the target device to infer on; CPU, GPU, \
+              MYRIAD, HDDL or HETERO: is acceptable. The sample will look \
+              for a suitable plugin for device specified. Default value \
+              is CPU.')
+    args.add_argument(
+        '--labels',
+        default=None,
+        type=str,
+        help='Option:al. Path to a labels mapping file.')
+    args.add_argument(
+        '-nt',
+        '--number_top',
+        default=10,
+        type=int,
+        help='Optional. Number of top results.')
+    return parser.parse_args()
+
+
+def main():
+    log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
+    args = parse_args()
+
+    # ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
+    log.info('Creating Inference Engine')
+    ie = IECore()
+
+    # ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
+    log.info(f'Reading the network: {args.model}')
+    # (.xml and .bin files) or (.onnx file)
+    net = ie.read_network(model=args.model)
+
+    if len(net.input_info) != 1:
+        log.error('Sample supports only single input topologies')
+        return -1
+    if len(net.outputs) != 1:
+        log.error('Sample supports only single output topologies')
+        return -1
+
+    # ---------------------------Step 3. Configure input & output----------------------------------------------------------
+    log.info('Configuring input and output blobs')
+    # Get names of input and output blobs
+    input_blob = next(iter(net.input_info))
+    out_blob = next(iter(net.outputs))
+
+    # Set input and output precision manually
+    net.input_info[input_blob].precision = 'FP32'
+    net.outputs[out_blob].precision = 'FP16'
+
+    # Get a number of classes recognized by a model
+    num_of_classes = max(net.outputs[out_blob].shape)
+
+    # ---------------------------Step 4. Loading model to the device-------------------------------------------------------
+    log.info('Loading the model to the plugin')
+    exec_net = ie.load_network(network=net, device_name=args.device)
+
+    # ---------------------------Step 5. Create infer request--------------------------------------------------------------
+    # load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork
+    # instance which stores infer requests. So you already created Infer requests in the previous step.
+
+    # ---------------------------Step 6. Prepare input---------------------------------------------------------------------
+    origin_img = cv2.imread(args.input)
+    _, _, h, w = net.input_info[input_blob].input_data.shape
+    image, ratio = preprocess(origin_img, (h, w))
+
+    # ---------------------------Step 7. Do inference----------------------------------------------------------------------
+    log.info('Starting inference in synchronous mode')
+    res = exec_net.infer(inputs={input_blob: image})
+
+    # ---------------------------Step 8. Process output--------------------------------------------------------------------
+    res = res[out_blob]
+
+    predictions = demo_postprocess(res, (h, w))[0]
+
+    boxes = predictions[:, :4]
+    scores = predictions[:, 4, None] * predictions[:, 5:]
+
+    boxes_xyxy = np.ones_like(boxes)
+    boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
+    boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
+    boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
+    boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
+    boxes_xyxy /= ratio
+    dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
+
+    if dets is not None:
+        final_boxes = dets[:, :4]
+        final_scores, final_cls_inds = dets[:, 4], dets[:, 5]
+        origin_img = vis(origin_img, final_boxes, final_scores, final_cls_inds,
+                         conf=args.score_thr, class_names=COCO_CLASSES)
+
+    mkdir(args.output_dir)
+    output_path = os.path.join(args.output_dir, os.path.basename(args.input))
+    cv2.imwrite(output_path, origin_img)
+
+
+if __name__ == '__main__':
+    sys.exit(main())
diff --git a/multimodal/YOLOX/demo/TensorRT/cpp/CMakeLists.txt b/multimodal/YOLOX/demo/TensorRT/cpp/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5f14edd594a5c9106bc5d50bc352c5bb14f716a4
--- /dev/null
+++ b/multimodal/YOLOX/demo/TensorRT/cpp/CMakeLists.txt
@@ -0,0 +1,36 @@
+cmake_minimum_required(VERSION 2.6)
+
+project(yolox)
+
+add_definitions(-std=c++11)
+
+option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
+set(CMAKE_CXX_STANDARD 11)
+set(CMAKE_BUILD_TYPE Debug)
+
+find_package(CUDA REQUIRED)
+
+include_directories(${PROJECT_SOURCE_DIR}/include)
+# include and link dirs of cuda and tensorrt, you need adapt them if yours are different
+# cuda
+include_directories(/data/cuda/cuda-10.2/cuda/include)
+link_directories(/data/cuda/cuda-10.2/cuda/lib64)
+# cudnn
+include_directories(/data/cuda/cuda-10.2/cudnn/v8.0.4/include)
+link_directories(/data/cuda/cuda-10.2/cudnn/v8.0.4/lib64)
+# tensorrt
+include_directories(/data/cuda/cuda-10.2/TensorRT/v7.2.1.6/include)
+link_directories(/data/cuda/cuda-10.2/TensorRT/v7.2.1.6/lib)
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")
+
+find_package(OpenCV)
+include_directories(${OpenCV_INCLUDE_DIRS})
+
+add_executable(yolox ${PROJECT_SOURCE_DIR}/yolox.cpp)
+target_link_libraries(yolox nvinfer)
+target_link_libraries(yolox cudart)
+target_link_libraries(yolox ${OpenCV_LIBS})
+
+add_definitions(-O2 -pthread)
+
diff --git a/multimodal/YOLOX/demo/TensorRT/cpp/README.md b/multimodal/YOLOX/demo/TensorRT/cpp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0158e7dacdf0af0d427a917e83adf8e7b4e02fac
--- /dev/null
+++ b/multimodal/YOLOX/demo/TensorRT/cpp/README.md
@@ -0,0 +1,48 @@
+# YOLOX-TensorRT in C++
+
+As YOLOX models are easy to convert to tensorrt using [torch2trt gitrepo](https://github.com/NVIDIA-AI-IOT/torch2trt), 
+our C++ demo does not include the model converting or constructing like other tenorrt demos.
+
+
+## Step 1: Prepare serialized engine file
+
+Follow the trt [python demo README](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/demo/TensorRT/python/README.md) to convert and save the serialized engine file.
+
+Check the 'model_trt.engine' file generated from Step 1, which will be automatically saved at the current demo dir.
+
+
+## Step 2: build the demo
+
+Please follow the [TensorRT Installation Guide](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) to install TensorRT.
+
+And you should set the TensorRT path and CUDA path in CMakeLists.txt.
+
+If you train your custom dataset, you may need to modify the value of `num_class`.
+
+```c++
+const int num_class = 80;
+```
+
+Install opencv with ```sudo apt-get install libopencv-dev``` (we don't need a higher version of opencv like v3.3+). 
+
+build the demo:
+
+```shell
+mkdir build
+cd build
+cmake ..
+make
+```
+
+Then run the demo:
+
+```shell
+./yolox ../model_trt.engine -i ../../../../assets/dog.jpg
+```
+
+or
+
+```shell
+./yolox <path/to/your/engine_file> -i <path/to/image>
+```
+
diff --git a/multimodal/YOLOX/demo/TensorRT/cpp/logging.h b/multimodal/YOLOX/demo/TensorRT/cpp/logging.h
new file mode 100644
index 0000000000000000000000000000000000000000..0edb75fab69b539b755422263c6f474576e21ee6
--- /dev/null
+++ b/multimodal/YOLOX/demo/TensorRT/cpp/logging.h
@@ -0,0 +1,503 @@
+/*
+ * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef TENSORRT_LOGGING_H
+#define TENSORRT_LOGGING_H
+
+#include "NvInferRuntimeCommon.h"
+#include <cassert>
+#include <ctime>
+#include <iomanip>
+#include <iostream>
+#include <ostream>
+#include <sstream>
+#include <string>
+
+using Severity = nvinfer1::ILogger::Severity;
+
+class LogStreamConsumerBuffer : public std::stringbuf
+{
+public:
+    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
+        : mOutput(stream)
+        , mPrefix(prefix)
+        , mShouldLog(shouldLog)
+    {
+    }
+
+    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
+        : mOutput(other.mOutput)
+    {
+    }
+
+    ~LogStreamConsumerBuffer()
+    {
+        // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
+        // std::streambuf::pptr() gives a pointer to the current position of the output sequence
+        // if the pointer to the beginning is not equal to the pointer to the current position,
+        // call putOutput() to log the output to the stream
+        if (pbase() != pptr())
+        {
+            putOutput();
+        }
+    }
+
+    // synchronizes the stream buffer and returns 0 on success
+    // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
+    // resetting the buffer and flushing the stream
+    virtual int sync()
+    {
+        putOutput();
+        return 0;
+    }
+
+    void putOutput()
+    {
+        if (mShouldLog)
+        {
+            // prepend timestamp
+            std::time_t timestamp = std::time(nullptr);
+            tm* tm_local = std::localtime(&timestamp);
+            std::cout << "[";
+            std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
+            std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
+            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
+            // std::stringbuf::str() gets the string contents of the buffer
+            // insert the buffer contents pre-appended by the appropriate prefix into the stream
+            mOutput << mPrefix << str();
+            // set the buffer to empty
+            str("");
+            // flush the stream
+            mOutput.flush();
+        }
+    }
+
+    void setShouldLog(bool shouldLog)
+    {
+        mShouldLog = shouldLog;
+    }
+
+private:
+    std::ostream& mOutput;
+    std::string mPrefix;
+    bool mShouldLog;
+};
+
+//!
+//! \class LogStreamConsumerBase
+//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
+//!
+class LogStreamConsumerBase
+{
+public:
+    LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
+        : mBuffer(stream, prefix, shouldLog)
+    {
+    }
+
+protected:
+    LogStreamConsumerBuffer mBuffer;
+};
+
+//!
+//! \class LogStreamConsumer
+//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
+//!  Order of base classes is LogStreamConsumerBase and then std::ostream.
+//!  This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
+//!  in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
+//!  This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
+//!  Please do not change the order of the parent classes.
+//!
+class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
+{
+public:
+    //! \brief Creates a LogStreamConsumer which logs messages with level severity.
+    //!  Reportable severity determines if the messages are severe enough to be logged.
+    LogStreamConsumer(Severity reportableSeverity, Severity severity)
+        : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
+        , std::ostream(&mBuffer) // links the stream buffer with the stream
+        , mShouldLog(severity <= reportableSeverity)
+        , mSeverity(severity)
+    {
+    }
+
+    LogStreamConsumer(LogStreamConsumer&& other)
+        : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
+        , std::ostream(&mBuffer) // links the stream buffer with the stream
+        , mShouldLog(other.mShouldLog)
+        , mSeverity(other.mSeverity)
+    {
+    }
+
+    void setReportableSeverity(Severity reportableSeverity)
+    {
+        mShouldLog = mSeverity <= reportableSeverity;
+        mBuffer.setShouldLog(mShouldLog);
+    }
+
+private:
+    static std::ostream& severityOstream(Severity severity)
+    {
+        return severity >= Severity::kINFO ? std::cout : std::cerr;
+    }
+
+    static std::string severityPrefix(Severity severity)
+    {
+        switch (severity)
+        {
+        case Severity::kINTERNAL_ERROR: return "[F] ";
+        case Severity::kERROR: return "[E] ";
+        case Severity::kWARNING: return "[W] ";
+        case Severity::kINFO: return "[I] ";
+        case Severity::kVERBOSE: return "[V] ";
+        default: assert(0); return "";
+        }
+    }
+
+    bool mShouldLog;
+    Severity mSeverity;
+};
+
+//! \class Logger
+//!
+//! \brief Class which manages logging of TensorRT tools and samples
+//!
+//! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
+//! and supports logging two types of messages:
+//!
+//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
+//! - Test pass/fail messages
+//!
+//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
+//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
+//!
+//! In the future, this class could be extended to support dumping test results to a file in some standard format
+//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
+//!
+//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
+//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
+//! library and messages coming from the sample.
+//!
+//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
+//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
+//! object.
+
+class Logger : public nvinfer1::ILogger
+{
+public:
+    Logger(Severity severity = Severity::kWARNING)
+        : mReportableSeverity(severity)
+    {
+    }
+
+    //!
+    //! \enum TestResult
+    //! \brief Represents the state of a given test
+    //!
+    enum class TestResult
+    {
+        kRUNNING, //!< The test is running
+        kPASSED,  //!< The test passed
+        kFAILED,  //!< The test failed
+        kWAIVED   //!< The test was waived
+    };
+
+    //!
+    //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
+    //! \return The nvinfer1::ILogger associated with this Logger
+    //!
+    //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
+    //! we can eliminate the inheritance of Logger from ILogger
+    //!
+    nvinfer1::ILogger& getTRTLogger()
+    {
+        return *this;
+    }
+
+    //!
+    //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
+    //!
+    //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
+    //! inheritance from nvinfer1::ILogger
+    //!
+    void log(Severity severity, const char* msg) noexcept override
+    {
+        LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
+    }
+
+    //!
+    //! \brief Method for controlling the verbosity of logging output
+    //!
+    //! \param severity The logger will only emit messages that have severity of this level or higher.
+    //!
+    void setReportableSeverity(Severity severity)
+    {
+        mReportableSeverity = severity;
+    }
+
+    //!
+    //! \brief Opaque handle that holds logging information for a particular test
+    //!
+    //! This object is an opaque handle to information used by the Logger to print test results.
+    //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
+    //! with Logger::reportTest{Start,End}().
+    //!
+    class TestAtom
+    {
+    public:
+        TestAtom(TestAtom&&) = default;
+
+    private:
+        friend class Logger;
+
+        TestAtom(bool started, const std::string& name, const std::string& cmdline)
+            : mStarted(started)
+            , mName(name)
+            , mCmdline(cmdline)
+        {
+        }
+
+        bool mStarted;
+        std::string mName;
+        std::string mCmdline;
+    };
+
+    //!
+    //! \brief Define a test for logging
+    //!
+    //! \param[in] name The name of the test.  This should be a string starting with
+    //!                  "TensorRT" and containing dot-separated strings containing
+    //!                  the characters [A-Za-z0-9_].
+    //!                  For example, "TensorRT.sample_googlenet"
+    //! \param[in] cmdline The command line used to reproduce the test
+    //
+    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
+    //!
+    static TestAtom defineTest(const std::string& name, const std::string& cmdline)
+    {
+        return TestAtom(false, name, cmdline);
+    }
+
+    //!
+    //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
+    //!        as input
+    //!
+    //! \param[in] name The name of the test
+    //! \param[in] argc The number of command-line arguments
+    //! \param[in] argv The array of command-line arguments (given as C strings)
+    //!
+    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
+    static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
+    {
+        auto cmdline = genCmdlineString(argc, argv);
+        return defineTest(name, cmdline);
+    }
+
+    //!
+    //! \brief Report that a test has started.
+    //!
+    //! \pre reportTestStart() has not been called yet for the given testAtom
+    //!
+    //! \param[in] testAtom The handle to the test that has started
+    //!
+    static void reportTestStart(TestAtom& testAtom)
+    {
+        reportTestResult(testAtom, TestResult::kRUNNING);
+        assert(!testAtom.mStarted);
+        testAtom.mStarted = true;
+    }
+
+    //!
+    //! \brief Report that a test has ended.
+    //!
+    //! \pre reportTestStart() has been called for the given testAtom
+    //!
+    //! \param[in] testAtom The handle to the test that has ended
+    //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
+    //!                   TestResult::kFAILED, TestResult::kWAIVED
+    //!
+    static void reportTestEnd(const TestAtom& testAtom, TestResult result)
+    {
+        assert(result != TestResult::kRUNNING);
+        assert(testAtom.mStarted);
+        reportTestResult(testAtom, result);
+    }
+
+    static int reportPass(const TestAtom& testAtom)
+    {
+        reportTestEnd(testAtom, TestResult::kPASSED);
+        return EXIT_SUCCESS;
+    }
+
+    static int reportFail(const TestAtom& testAtom)
+    {
+        reportTestEnd(testAtom, TestResult::kFAILED);
+        return EXIT_FAILURE;
+    }
+
+    static int reportWaive(const TestAtom& testAtom)
+    {
+        reportTestEnd(testAtom, TestResult::kWAIVED);
+        return EXIT_SUCCESS;
+    }
+
+    static int reportTest(const TestAtom& testAtom, bool pass)
+    {
+        return pass ? reportPass(testAtom) : reportFail(testAtom);
+    }
+
+    Severity getReportableSeverity() const
+    {
+        return mReportableSeverity;
+    }
+
+private:
+    //!
+    //! \brief returns an appropriate string for prefixing a log message with the given severity
+    //!
+    static const char* severityPrefix(Severity severity)
+    {
+        switch (severity)
+        {
+        case Severity::kINTERNAL_ERROR: return "[F] ";
+        case Severity::kERROR: return "[E] ";
+        case Severity::kWARNING: return "[W] ";
+        case Severity::kINFO: return "[I] ";
+        case Severity::kVERBOSE: return "[V] ";
+        default: assert(0); return "";
+        }
+    }
+
+    //!
+    //! \brief returns an appropriate string for prefixing a test result message with the given result
+    //!
+    static const char* testResultString(TestResult result)
+    {
+        switch (result)
+        {
+        case TestResult::kRUNNING: return "RUNNING";
+        case TestResult::kPASSED: return "PASSED";
+        case TestResult::kFAILED: return "FAILED";
+        case TestResult::kWAIVED: return "WAIVED";
+        default: assert(0); return "";
+        }
+    }
+
+    //!
+    //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
+    //!
+    static std::ostream& severityOstream(Severity severity)
+    {
+        return severity >= Severity::kINFO ? std::cout : std::cerr;
+    }
+
+    //!
+    //! \brief method that implements logging test results
+    //!
+    static void reportTestResult(const TestAtom& testAtom, TestResult result)
+    {
+        severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
+                                         << testAtom.mCmdline << std::endl;
+    }
+
+    //!
+    //! \brief generate a command line string from the given (argc, argv) values
+    //!
+    static std::string genCmdlineString(int argc, char const* const* argv)
+    {
+        std::stringstream ss;
+        for (int i = 0; i < argc; i++)
+        {
+            if (i > 0)
+                ss << " ";
+            ss << argv[i];
+        }
+        return ss.str();
+    }
+
+    Severity mReportableSeverity;
+};
+
+namespace
+{
+
+//!
+//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
+//!
+//! Example usage:
+//!
+//!     LOG_VERBOSE(logger) << "hello world" << std::endl;
+//!
+inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
+{
+    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
+}
+
+//!
+//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
+//!
+//! Example usage:
+//!
+//!     LOG_INFO(logger) << "hello world" << std::endl;
+//!
+inline LogStreamConsumer LOG_INFO(const Logger& logger)
+{
+    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
+}
+
+//!
+//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
+//!
+//! Example usage:
+//!
+//!     LOG_WARN(logger) << "hello world" << std::endl;
+//!
+inline LogStreamConsumer LOG_WARN(const Logger& logger)
+{
+    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
+}
+
+//!
+//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
+//!
+//! Example usage:
+//!
+//!     LOG_ERROR(logger) << "hello world" << std::endl;
+//!
+inline LogStreamConsumer LOG_ERROR(const Logger& logger)
+{
+    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
+}
+
+//!
+//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
+//         ("fatal" severity)
+//!
+//! Example usage:
+//!
+//!     LOG_FATAL(logger) << "hello world" << std::endl;
+//!
+inline LogStreamConsumer LOG_FATAL(const Logger& logger)
+{
+    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
+}
+
+} // anonymous namespace
+
+#endif // TENSORRT_LOGGING_H
diff --git a/multimodal/YOLOX/demo/TensorRT/cpp/yolox.cpp b/multimodal/YOLOX/demo/TensorRT/cpp/yolox.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ed423380ef35b4c39bf3231bac6e0079f7eea589
--- /dev/null
+++ b/multimodal/YOLOX/demo/TensorRT/cpp/yolox.cpp
@@ -0,0 +1,530 @@
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <numeric>
+#include <chrono>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <dirent.h>
+#include "NvInfer.h"
+#include "cuda_runtime_api.h"
+#include "logging.h"
+
+#define CHECK(status) \
+    do\
+    {\
+        auto ret = (status);\
+        if (ret != 0)\
+        {\
+            std::cerr << "Cuda failure: " << ret << std::endl;\
+            abort();\
+        }\
+    } while (0)
+
+#define DEVICE 0  // GPU id
+#define NMS_THRESH 0.45
+#define BBOX_CONF_THRESH 0.3
+
+using namespace nvinfer1;
+
+// stuff we know about the network and the input/output blobs
+static const int INPUT_W = 640;
+static const int INPUT_H = 640;
+static const int NUM_CLASSES = 80;
+const char* INPUT_BLOB_NAME = "input_0";
+const char* OUTPUT_BLOB_NAME = "output_0";
+static Logger gLogger;
+
+cv::Mat static_resize(cv::Mat& img) {
+    float r = std::min(INPUT_W / (img.cols*1.0), INPUT_H / (img.rows*1.0));
+    // r = std::min(r, 1.0f);
+    int unpad_w = r * img.cols;
+    int unpad_h = r * img.rows;
+    cv::Mat re(unpad_h, unpad_w, CV_8UC3);
+    cv::resize(img, re, re.size());
+    cv::Mat out(INPUT_H, INPUT_W, CV_8UC3, cv::Scalar(114, 114, 114));
+    re.copyTo(out(cv::Rect(0, 0, re.cols, re.rows)));
+    return out;
+}
+
+struct Object
+{
+    cv::Rect_<float> rect;
+    int label;
+    float prob;
+};
+
+struct GridAndStride
+{
+    int grid0;
+    int grid1;
+    int stride;
+};
+
+static void generate_grids_and_stride(std::vector<int>& strides, std::vector<GridAndStride>& grid_strides)
+{
+    for (auto stride : strides)
+    {
+        int num_grid_y = INPUT_H / stride;
+        int num_grid_x = INPUT_W / stride;
+        for (int g1 = 0; g1 < num_grid_y; g1++)
+        {
+            for (int g0 = 0; g0 < num_grid_x; g0++)
+            {
+                grid_strides.push_back((GridAndStride){g0, g1, stride});
+            }
+        }
+    }
+}
+
+static inline float intersection_area(const Object& a, const Object& b)
+{
+    cv::Rect_<float> inter = a.rect & b.rect;
+    return inter.area();
+}
+
+static void qsort_descent_inplace(std::vector<Object>& faceobjects, int left, int right)
+{
+    int i = left;
+    int j = right;
+    float p = faceobjects[(left + right) / 2].prob;
+
+    while (i <= j)
+    {
+        while (faceobjects[i].prob > p)
+            i++;
+
+        while (faceobjects[j].prob < p)
+            j--;
+
+        if (i <= j)
+        {
+            // swap
+            std::swap(faceobjects[i], faceobjects[j]);
+
+            i++;
+            j--;
+        }
+    }
+
+    #pragma omp parallel sections
+    {
+        #pragma omp section
+        {
+            if (left < j) qsort_descent_inplace(faceobjects, left, j);
+        }
+        #pragma omp section
+        {
+            if (i < right) qsort_descent_inplace(faceobjects, i, right);
+        }
+    }
+}
+
+static void qsort_descent_inplace(std::vector<Object>& objects)
+{
+    if (objects.empty())
+        return;
+
+    qsort_descent_inplace(objects, 0, objects.size() - 1);
+}
+
+static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold)
+{
+    picked.clear();
+
+    const int n = faceobjects.size();
+
+    std::vector<float> areas(n);
+    for (int i = 0; i < n; i++)
+    {
+        areas[i] = faceobjects[i].rect.area();
+    }
+
+    for (int i = 0; i < n; i++)
+    {
+        const Object& a = faceobjects[i];
+
+        int keep = 1;
+        for (int j = 0; j < (int)picked.size(); j++)
+        {
+            const Object& b = faceobjects[picked[j]];
+
+            // intersection over union
+            float inter_area = intersection_area(a, b);
+            float union_area = areas[i] + areas[picked[j]] - inter_area;
+            // float IoU = inter_area / union_area
+            if (inter_area / union_area > nms_threshold)
+                keep = 0;
+        }
+
+        if (keep)
+            picked.push_back(i);
+    }
+}
+
+
+static void generate_yolox_proposals(std::vector<GridAndStride> grid_strides, float* feat_blob, float prob_threshold, std::vector<Object>& objects)
+{
+
+    const int num_anchors = grid_strides.size();
+
+    for (int anchor_idx = 0; anchor_idx < num_anchors; anchor_idx++)
+    {
+        const int grid0 = grid_strides[anchor_idx].grid0;
+        const int grid1 = grid_strides[anchor_idx].grid1;
+        const int stride = grid_strides[anchor_idx].stride;
+
+        const int basic_pos = anchor_idx * (NUM_CLASSES + 5);
+
+        // yolox/models/yolo_head.py decode logic
+        float x_center = (feat_blob[basic_pos+0] + grid0) * stride;
+        float y_center = (feat_blob[basic_pos+1] + grid1) * stride;
+        float w = exp(feat_blob[basic_pos+2]) * stride;
+        float h = exp(feat_blob[basic_pos+3]) * stride;
+        float x0 = x_center - w * 0.5f;
+        float y0 = y_center - h * 0.5f;
+
+        float box_objectness = feat_blob[basic_pos+4];
+        for (int class_idx = 0; class_idx < NUM_CLASSES; class_idx++)
+        {
+            float box_cls_score = feat_blob[basic_pos + 5 + class_idx];
+            float box_prob = box_objectness * box_cls_score;
+            if (box_prob > prob_threshold)
+            {
+                Object obj;
+                obj.rect.x = x0;
+                obj.rect.y = y0;
+                obj.rect.width = w;
+                obj.rect.height = h;
+                obj.label = class_idx;
+                obj.prob = box_prob;
+
+                objects.push_back(obj);
+            }
+
+        } // class loop
+
+    } // point anchor loop
+}
+
+float* blobFromImage(cv::Mat& img){
+    float* blob = new float[img.total()*3];
+    int channels = 3;
+    int img_h = img.rows;
+    int img_w = img.cols;
+    for (size_t c = 0; c < channels; c++) 
+    {
+        for (size_t  h = 0; h < img_h; h++) 
+        {
+            for (size_t w = 0; w < img_w; w++) 
+            {
+                blob[c * img_w * img_h + h * img_w + w] =
+                    (float)img.at<cv::Vec3b>(h, w)[c];
+            }
+        }
+    }
+    return blob;
+}
+
+
+static void decode_outputs(float* prob, std::vector<Object>& objects, float scale, const int img_w, const int img_h) {
+        std::vector<Object> proposals;
+        std::vector<int> strides = {8, 16, 32};
+        std::vector<GridAndStride> grid_strides;
+        generate_grids_and_stride(strides, grid_strides);
+        generate_yolox_proposals(grid_strides, prob,  BBOX_CONF_THRESH, proposals);
+        std::cout << "num of boxes before nms: " << proposals.size() << std::endl;
+
+        qsort_descent_inplace(proposals);
+
+        std::vector<int> picked;
+        nms_sorted_bboxes(proposals, picked, NMS_THRESH);
+
+
+        int count = picked.size();
+
+        std::cout << "num of boxes: " << count << std::endl;
+
+        objects.resize(count);
+        for (int i = 0; i < count; i++)
+        {
+            objects[i] = proposals[picked[i]];
+
+            // adjust offset to original unpadded
+            float x0 = (objects[i].rect.x) / scale;
+            float y0 = (objects[i].rect.y) / scale;
+            float x1 = (objects[i].rect.x + objects[i].rect.width) / scale;
+            float y1 = (objects[i].rect.y + objects[i].rect.height) / scale;
+
+            // clip
+            x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
+            y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
+            x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
+            y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);
+
+            objects[i].rect.x = x0;
+            objects[i].rect.y = y0;
+            objects[i].rect.width = x1 - x0;
+            objects[i].rect.height = y1 - y0;
+        }
+}
+
+const float color_list[80][3] =
+{
+    {0.000, 0.447, 0.741},
+    {0.850, 0.325, 0.098},
+    {0.929, 0.694, 0.125},
+    {0.494, 0.184, 0.556},
+    {0.466, 0.674, 0.188},
+    {0.301, 0.745, 0.933},
+    {0.635, 0.078, 0.184},
+    {0.300, 0.300, 0.300},
+    {0.600, 0.600, 0.600},
+    {1.000, 0.000, 0.000},
+    {1.000, 0.500, 0.000},
+    {0.749, 0.749, 0.000},
+    {0.000, 1.000, 0.000},
+    {0.000, 0.000, 1.000},
+    {0.667, 0.000, 1.000},
+    {0.333, 0.333, 0.000},
+    {0.333, 0.667, 0.000},
+    {0.333, 1.000, 0.000},
+    {0.667, 0.333, 0.000},
+    {0.667, 0.667, 0.000},
+    {0.667, 1.000, 0.000},
+    {1.000, 0.333, 0.000},
+    {1.000, 0.667, 0.000},
+    {1.000, 1.000, 0.000},
+    {0.000, 0.333, 0.500},
+    {0.000, 0.667, 0.500},
+    {0.000, 1.000, 0.500},
+    {0.333, 0.000, 0.500},
+    {0.333, 0.333, 0.500},
+    {0.333, 0.667, 0.500},
+    {0.333, 1.000, 0.500},
+    {0.667, 0.000, 0.500},
+    {0.667, 0.333, 0.500},
+    {0.667, 0.667, 0.500},
+    {0.667, 1.000, 0.500},
+    {1.000, 0.000, 0.500},
+    {1.000, 0.333, 0.500},
+    {1.000, 0.667, 0.500},
+    {1.000, 1.000, 0.500},
+    {0.000, 0.333, 1.000},
+    {0.000, 0.667, 1.000},
+    {0.000, 1.000, 1.000},
+    {0.333, 0.000, 1.000},
+    {0.333, 0.333, 1.000},
+    {0.333, 0.667, 1.000},
+    {0.333, 1.000, 1.000},
+    {0.667, 0.000, 1.000},
+    {0.667, 0.333, 1.000},
+    {0.667, 0.667, 1.000},
+    {0.667, 1.000, 1.000},
+    {1.000, 0.000, 1.000},
+    {1.000, 0.333, 1.000},
+    {1.000, 0.667, 1.000},
+    {0.333, 0.000, 0.000},
+    {0.500, 0.000, 0.000},
+    {0.667, 0.000, 0.000},
+    {0.833, 0.000, 0.000},
+    {1.000, 0.000, 0.000},
+    {0.000, 0.167, 0.000},
+    {0.000, 0.333, 0.000},
+    {0.000, 0.500, 0.000},
+    {0.000, 0.667, 0.000},
+    {0.000, 0.833, 0.000},
+    {0.000, 1.000, 0.000},
+    {0.000, 0.000, 0.167},
+    {0.000, 0.000, 0.333},
+    {0.000, 0.000, 0.500},
+    {0.000, 0.000, 0.667},
+    {0.000, 0.000, 0.833},
+    {0.000, 0.000, 1.000},
+    {0.000, 0.000, 0.000},
+    {0.143, 0.143, 0.143},
+    {0.286, 0.286, 0.286},
+    {0.429, 0.429, 0.429},
+    {0.571, 0.571, 0.571},
+    {0.714, 0.714, 0.714},
+    {0.857, 0.857, 0.857},
+    {0.000, 0.447, 0.741},
+    {0.314, 0.717, 0.741},
+    {0.50, 0.5, 0}
+};
+
+static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects, std::string f)
+{
+    static const char* class_names[] = {
+        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
+        "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
+        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
+        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
+        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
+        "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
+        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
+        "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
+        "hair drier", "toothbrush"
+    };
+
+    cv::Mat image = bgr.clone();
+
+    for (size_t i = 0; i < objects.size(); i++)
+    {
+        const Object& obj = objects[i];
+
+        fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
+                obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);
+
+        cv::Scalar color = cv::Scalar(color_list[obj.label][0], color_list[obj.label][1], color_list[obj.label][2]);
+        float c_mean = cv::mean(color)[0];
+        cv::Scalar txt_color;
+        if (c_mean > 0.5){
+            txt_color = cv::Scalar(0, 0, 0);
+        }else{
+            txt_color = cv::Scalar(255, 255, 255);
+        }
+
+        cv::rectangle(image, obj.rect, color * 255, 2);
+
+        char text[256];
+        sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
+
+        int baseLine = 0;
+        cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.4, 1, &baseLine);
+
+        cv::Scalar txt_bk_color = color * 0.7 * 255;
+
+        int x = obj.rect.x;
+        int y = obj.rect.y + 1;
+        //int y = obj.rect.y - label_size.height - baseLine;
+        if (y > image.rows)
+            y = image.rows;
+        //if (x + label_size.width > image.cols)
+            //x = image.cols - label_size.width;
+
+        cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
+                      txt_bk_color, -1);
+
+        cv::putText(image, text, cv::Point(x, y + label_size.height),
+                    cv::FONT_HERSHEY_SIMPLEX, 0.4, txt_color, 1);
+    }
+
+    cv::imwrite("det_res.jpg", image);
+    fprintf(stderr, "save vis file\n");
+    /* cv::imshow("image", image); */
+    /* cv::waitKey(0); */
+}
+
+
+void doInference(IExecutionContext& context, float* input, float* output, const int output_size, cv::Size input_shape) {
+    const ICudaEngine& engine = context.getEngine();
+
+    // Pointers to input and output device buffers to pass to engine.
+    // Engine requires exactly IEngine::getNbBindings() number of buffers.
+    assert(engine.getNbBindings() == 2);
+    void* buffers[2];
+
+    // In order to bind the buffers, we need to know the names of the input and output tensors.
+    // Note that indices are guaranteed to be less than IEngine::getNbBindings()
+    const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
+
+    assert(engine.getBindingDataType(inputIndex) == nvinfer1::DataType::kFLOAT);
+    const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);
+    assert(engine.getBindingDataType(outputIndex) == nvinfer1::DataType::kFLOAT);
+    int mBatchSize = engine.getMaxBatchSize();
+
+    // Create GPU buffers on device
+    CHECK(cudaMalloc(&buffers[inputIndex], 3 * input_shape.height * input_shape.width * sizeof(float)));
+    CHECK(cudaMalloc(&buffers[outputIndex], output_size*sizeof(float)));
+
+    // Create stream
+    cudaStream_t stream;
+    CHECK(cudaStreamCreate(&stream));
+
+    // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
+    CHECK(cudaMemcpyAsync(buffers[inputIndex], input, 3 * input_shape.height * input_shape.width * sizeof(float), cudaMemcpyHostToDevice, stream));
+    context.enqueue(1, buffers, stream, nullptr);
+    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], output_size * sizeof(float), cudaMemcpyDeviceToHost, stream));
+    cudaStreamSynchronize(stream);
+
+    // Release stream and buffers
+    cudaStreamDestroy(stream);
+    CHECK(cudaFree(buffers[inputIndex]));
+    CHECK(cudaFree(buffers[outputIndex]));
+}
+
+int main(int argc, char** argv) {
+    cudaSetDevice(DEVICE);
+    // create a model using the API directly and serialize it to a stream
+    char *trtModelStream{nullptr};
+    size_t size{0};
+
+    if (argc == 4 && std::string(argv[2]) == "-i") {
+        const std::string engine_file_path {argv[1]};
+        std::ifstream file(engine_file_path, std::ios::binary);
+        if (file.good()) {
+            file.seekg(0, file.end);
+            size = file.tellg();
+            file.seekg(0, file.beg);
+            trtModelStream = new char[size];
+            assert(trtModelStream);
+            file.read(trtModelStream, size);
+            file.close();
+        }
+    } else {
+        std::cerr << "arguments not right!" << std::endl;
+        std::cerr << "run 'python3 yolox/deploy/trt.py -n yolox-{tiny, s, m, l, x}' to serialize model first!" << std::endl;
+        std::cerr << "Then use the following command:" << std::endl;
+        std::cerr << "./yolox ../model_trt.engine -i ../../../assets/dog.jpg  // deserialize file and run inference" << std::endl;
+        return -1;
+    }
+    const std::string input_image_path {argv[3]};
+
+    //std::vector<std::string> file_names;
+    //if (read_files_in_dir(argv[2], file_names) < 0) {
+        //std::cout << "read_files_in_dir failed." << std::endl;
+        //return -1;
+    //}
+
+    IRuntime* runtime = createInferRuntime(gLogger);
+    assert(runtime != nullptr);
+    ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size);
+    assert(engine != nullptr); 
+    IExecutionContext* context = engine->createExecutionContext();
+    assert(context != nullptr);
+    delete[] trtModelStream;
+    auto out_dims = engine->getBindingDimensions(1);
+    auto output_size = 1;
+    for(int j=0;j<out_dims.nbDims;j++) {
+        output_size *= out_dims.d[j];
+    }
+    static float* prob = new float[output_size];
+
+    cv::Mat img = cv::imread(input_image_path);
+    int img_w = img.cols;
+    int img_h = img.rows;
+    cv::Mat pr_img = static_resize(img);
+    std::cout << "blob image" << std::endl;
+
+    float* blob;
+    blob = blobFromImage(pr_img);
+    float scale = std::min(INPUT_W / (img.cols*1.0), INPUT_H / (img.rows*1.0));
+
+    // run inference
+    auto start = std::chrono::system_clock::now();
+    doInference(*context, blob, prob, output_size, pr_img.size());
+    auto end = std::chrono::system_clock::now();
+    std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
+
+    std::vector<Object> objects;
+    decode_outputs(prob, objects, scale, img_w, img_h);
+    draw_objects(img, objects, input_image_path);
+    // delete the pointer to the float
+    delete blob;
+    // destroy the engine
+    context->destroy();
+    engine->destroy();
+    runtime->destroy();
+    return 0;
+}
diff --git a/multimodal/YOLOX/demo/TensorRT/python/README.md b/multimodal/YOLOX/demo/TensorRT/python/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..236eeb1265344b68e24616293c96fffee9a17262
--- /dev/null
+++ b/multimodal/YOLOX/demo/TensorRT/python/README.md
@@ -0,0 +1,46 @@
+# YOLOX-TensorRT in Python
+
+This tutorial includes a Python demo for TensorRT.
+
+## Install TensorRT Toolkit
+
+Please follow the [TensorRT Installation Guide](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html) and [torch2trt gitrepo](https://github.com/NVIDIA-AI-IOT/torch2trt) to install TensorRT and torch2trt.
+
+## Convert model
+
+YOLOX models can be easily conveted to TensorRT models using torch2trt
+
+   If you want to convert our model, use the flag -n to specify a model name:
+   ```shell
+   python tools/trt.py -n <YOLOX_MODEL_NAME> -c <YOLOX_CHECKPOINT>
+   ```
+   For example:
+   ```shell
+   python tools/trt.py -n yolox-s -c your_ckpt.pth
+   ```
+   <YOLOX_MODEL_NAME> can be: yolox-nano, yolox-tiny. yolox-s, yolox-m, yolox-l, yolox-x.
+
+   If you want to convert your customized model, use the flag -f to specify you exp file:
+   ```shell
+   python tools/trt.py -f <YOLOX_EXP_FILE> -c <YOLOX_CHECKPOINT>
+   ```
+   For example:
+   ```shell
+   python tools/trt.py -f /path/to/your/yolox/exps/yolox_s.py -c your_ckpt.pth
+   ```
+   *yolox_s.py* can be any exp file modified by you.
+
+The converted model and the serialized engine file (for C++ demo) will be saved on your experiment output dir.  
+
+## Demo
+
+The TensorRT python demo is merged on our pytorch demo file, so you can run the pytorch demo command with ```--trt```.
+
+```shell
+python tools/demo.py image -n yolox-s --trt --save_result
+```
+or
+```shell
+python tools/demo.py image -f exps/default/yolox_s.py --trt --save_result
+```
+
diff --git a/multimodal/YOLOX/demo/ncnn/README.md b/multimodal/YOLOX/demo/ncnn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a607abd8caad31c51749884cc433082202ba01af
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/README.md
@@ -0,0 +1,8 @@
+# YOLOX-ncnn
+
+Compile files of YOLOX object detection base on [ncnn](https://github.com/Tencent/ncnn).  
+YOLOX is included in ncnn now, you could also try building from ncnn, it's better.
+
+## Acknowledgement
+
+* [ncnn](https://github.com/Tencent/ncnn)
diff --git a/multimodal/YOLOX/demo/ncnn/android/README.md b/multimodal/YOLOX/demo/ncnn/android/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2197ffe9a348d20f541d0e664363e07dfaf425ac
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/README.md
@@ -0,0 +1,27 @@
+# YOLOX-Android-ncnn
+
+Andoird app of YOLOX object detection base on [ncnn](https://github.com/Tencent/ncnn)
+
+
+## Tutorial
+
+### Step1
+
+Download ncnn-android-vulkan.zip from [releases of ncnn](https://github.com/Tencent/ncnn/releases). This repo uses
+[20210525 release](https://github.com/Tencent/ncnn/releases/download/20210525/ncnn-20210525-android-vulkan.zip) for building.
+
+### Step2
+
+After downloading, please extract your zip file. Then, there are two ways to finish this step:
+* put your extracted directory into **app/src/main/jni**
+* change the **ncnn_DIR** path in **app/src/main/jni/CMakeLists.txt** to your extracted directory
+
+### Step3
+Download example param and bin file from [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/ESXBH_GSSmFMszWJ6YG2VkQB5cWDfqVWXgk0D996jH0rpQ?e=qzEqUh) or [github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_s_ncnn.tar.gz). Unzip the file to **app/src/main/assets**.
+
+### Step4
+Open this project with Android Studio, build it and enjoy!
+
+## Reference
+
+* [ncnn-android-yolov5](https://github.com/nihui/ncnn-android-yolov5)
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/build.gradle b/multimodal/YOLOX/demo/ncnn/android/app/build.gradle
new file mode 100644
index 0000000000000000000000000000000000000000..72e5ce088e9656749644edddfb7ca5d39f2b67f1
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/build.gradle
@@ -0,0 +1,24 @@
+apply plugin: 'com.android.application'
+
+android {
+    compileSdkVersion 24
+    buildToolsVersion "29.0.2"
+
+    defaultConfig {
+        applicationId "com.megvii.yoloXncnn"
+        archivesBaseName = "$applicationId"
+
+        ndk {
+            moduleName "ncnn"
+            abiFilters "armeabi-v7a", "arm64-v8a"
+        }
+        minSdkVersion 24
+    }
+
+    externalNativeBuild {
+        cmake {
+            version "3.10.2"
+            path file('src/main/jni/CMakeLists.txt')
+        }
+    }
+}
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/AndroidManifest.xml b/multimodal/YOLOX/demo/ncnn/android/app/src/main/AndroidManifest.xml
new file mode 100644
index 0000000000000000000000000000000000000000..f69b9a0f1891adae1bd88df713980f1cfd0d1e92
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/AndroidManifest.xml
@@ -0,0 +1,15 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+      package="com.megvii.yoloXncnn"
+      android:versionCode="1"
+      android:versionName="1.1">
+    <application android:label="@string/app_name" >
+        <activity android:name="MainActivity"
+                  android:label="@string/app_name">
+            <intent-filter>
+                <action android:name="android.intent.action.MAIN" />
+                <category android:name="android.intent.category.LAUNCHER" />
+            </intent-filter>
+        </activity>
+    </application>
+</manifest> 
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/assets/yolox.param b/multimodal/YOLOX/demo/ncnn/android/app/src/main/assets/yolox.param
new file mode 100644
index 0000000000000000000000000000000000000000..f7990f7ae9a71451bf8abb14cfedc74d9cbc38cc
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/assets/yolox.param
@@ -0,0 +1,222 @@
+7767517
+220 250
+Input                    images                   0 1 images
+YoloV5Focus              focus                    1 1 images 503
+Convolution              Conv_41                  1 1 503 877 0=32 1=3 4=1 5=1 6=3456
+Swish                    Mul_43                   1 1 877 507
+Convolution              Conv_44                  1 1 507 880 0=64 1=3 3=2 4=1 5=1 6=18432
+Swish                    Mul_46                   1 1 880 511
+Split                    splitncnn_0              1 2 511 511_splitncnn_0 511_splitncnn_1
+Convolution              Conv_47                  1 1 511_splitncnn_1 883 0=32 1=1 5=1 6=2048
+Swish                    Mul_49                   1 1 883 515
+Split                    splitncnn_1              1 2 515 515_splitncnn_0 515_splitncnn_1
+Convolution              Conv_50                  1 1 511_splitncnn_0 886 0=32 1=1 5=1 6=2048
+Swish                    Mul_52                   1 1 886 519
+Convolution              Conv_53                  1 1 515_splitncnn_1 889 0=32 1=1 5=1 6=1024
+Swish                    Mul_55                   1 1 889 523
+Convolution              Conv_56                  1 1 523 892 0=32 1=3 4=1 5=1 6=9216
+Swish                    Mul_58                   1 1 892 527
+BinaryOp                 Add_59                   2 1 527 515_splitncnn_0 528
+Concat                   Concat_60                2 1 528 519 529
+Convolution              Conv_61                  1 1 529 895 0=64 1=1 5=1 6=4096
+Swish                    Mul_63                   1 1 895 533
+Convolution              Conv_64                  1 1 533 898 0=128 1=3 3=2 4=1 5=1 6=73728
+Swish                    Mul_66                   1 1 898 537
+Split                    splitncnn_2              1 2 537 537_splitncnn_0 537_splitncnn_1
+Convolution              Conv_67                  1 1 537_splitncnn_1 901 0=64 1=1 5=1 6=8192
+Swish                    Mul_69                   1 1 901 541
+Split                    splitncnn_3              1 2 541 541_splitncnn_0 541_splitncnn_1
+Convolution              Conv_70                  1 1 537_splitncnn_0 904 0=64 1=1 5=1 6=8192
+Swish                    Mul_72                   1 1 904 545
+Convolution              Conv_73                  1 1 541_splitncnn_1 907 0=64 1=1 5=1 6=4096
+Swish                    Mul_75                   1 1 907 549
+Convolution              Conv_76                  1 1 549 910 0=64 1=3 4=1 5=1 6=36864
+Swish                    Mul_78                   1 1 910 553
+BinaryOp                 Add_79                   2 1 553 541_splitncnn_0 554
+Split                    splitncnn_4              1 2 554 554_splitncnn_0 554_splitncnn_1
+Convolution              Conv_80                  1 1 554_splitncnn_1 913 0=64 1=1 5=1 6=4096
+Swish                    Mul_82                   1 1 913 558
+Convolution              Conv_83                  1 1 558 916 0=64 1=3 4=1 5=1 6=36864
+Swish                    Mul_85                   1 1 916 562
+BinaryOp                 Add_86                   2 1 562 554_splitncnn_0 563
+Split                    splitncnn_5              1 2 563 563_splitncnn_0 563_splitncnn_1
+Convolution              Conv_87                  1 1 563_splitncnn_1 919 0=64 1=1 5=1 6=4096
+Swish                    Mul_89                   1 1 919 567
+Convolution              Conv_90                  1 1 567 922 0=64 1=3 4=1 5=1 6=36864
+Swish                    Mul_92                   1 1 922 571
+BinaryOp                 Add_93                   2 1 571 563_splitncnn_0 572
+Concat                   Concat_94                2 1 572 545 573
+Convolution              Conv_95                  1 1 573 925 0=128 1=1 5=1 6=16384
+Swish                    Mul_97                   1 1 925 577
+Split                    splitncnn_6              1 2 577 577_splitncnn_0 577_splitncnn_1
+Convolution              Conv_98                  1 1 577_splitncnn_1 928 0=256 1=3 3=2 4=1 5=1 6=294912
+Swish                    Mul_100                  1 1 928 581
+Split                    splitncnn_7              1 2 581 581_splitncnn_0 581_splitncnn_1
+Convolution              Conv_101                 1 1 581_splitncnn_1 931 0=128 1=1 5=1 6=32768
+Swish                    Mul_103                  1 1 931 585
+Split                    splitncnn_8              1 2 585 585_splitncnn_0 585_splitncnn_1
+Convolution              Conv_104                 1 1 581_splitncnn_0 934 0=128 1=1 5=1 6=32768
+Swish                    Mul_106                  1 1 934 589
+Convolution              Conv_107                 1 1 585_splitncnn_1 937 0=128 1=1 5=1 6=16384
+Swish                    Mul_109                  1 1 937 593
+Convolution              Conv_110                 1 1 593 940 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_112                  1 1 940 597
+BinaryOp                 Add_113                  2 1 597 585_splitncnn_0 598
+Split                    splitncnn_9              1 2 598 598_splitncnn_0 598_splitncnn_1
+Convolution              Conv_114                 1 1 598_splitncnn_1 943 0=128 1=1 5=1 6=16384
+Swish                    Mul_116                  1 1 943 602
+Convolution              Conv_117                 1 1 602 946 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_119                  1 1 946 606
+BinaryOp                 Add_120                  2 1 606 598_splitncnn_0 607
+Split                    splitncnn_10             1 2 607 607_splitncnn_0 607_splitncnn_1
+Convolution              Conv_121                 1 1 607_splitncnn_1 949 0=128 1=1 5=1 6=16384
+Swish                    Mul_123                  1 1 949 611
+Convolution              Conv_124                 1 1 611 952 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_126                  1 1 952 615
+BinaryOp                 Add_127                  2 1 615 607_splitncnn_0 616
+Concat                   Concat_128               2 1 616 589 617
+Convolution              Conv_129                 1 1 617 955 0=256 1=1 5=1 6=65536
+Swish                    Mul_131                  1 1 955 621
+Split                    splitncnn_11             1 2 621 621_splitncnn_0 621_splitncnn_1
+Convolution              Conv_132                 1 1 621_splitncnn_1 958 0=512 1=3 3=2 4=1 5=1 6=1179648
+Swish                    Mul_134                  1 1 958 625
+Convolution              Conv_135                 1 1 625 961 0=256 1=1 5=1 6=131072
+Swish                    Mul_137                  1 1 961 629
+Split                    splitncnn_12             1 4 629 629_splitncnn_0 629_splitncnn_1 629_splitncnn_2 629_splitncnn_3
+Pooling                  MaxPool_138              1 1 629_splitncnn_3 630 1=5 3=2 5=1
+Pooling                  MaxPool_139              1 1 629_splitncnn_2 631 1=9 3=4 5=1
+Pooling                  MaxPool_140              1 1 629_splitncnn_1 632 1=13 3=6 5=1
+Concat                   Concat_141               4 1 629_splitncnn_0 630 631 632 633
+Convolution              Conv_142                 1 1 633 964 0=512 1=1 5=1 6=524288
+Swish                    Mul_144                  1 1 964 637
+Split                    splitncnn_13             1 2 637 637_splitncnn_0 637_splitncnn_1
+Convolution              Conv_145                 1 1 637_splitncnn_1 967 0=256 1=1 5=1 6=131072
+Swish                    Mul_147                  1 1 967 641
+Convolution              Conv_148                 1 1 637_splitncnn_0 970 0=256 1=1 5=1 6=131072
+Swish                    Mul_150                  1 1 970 645
+Convolution              Conv_151                 1 1 641 973 0=256 1=1 5=1 6=65536
+Swish                    Mul_153                  1 1 973 649
+Convolution              Conv_154                 1 1 649 976 0=256 1=3 4=1 5=1 6=589824
+Swish                    Mul_156                  1 1 976 653
+Concat                   Concat_157               2 1 653 645 654
+Convolution              Conv_158                 1 1 654 979 0=512 1=1 5=1 6=262144
+Swish                    Mul_160                  1 1 979 658
+Convolution              Conv_161                 1 1 658 982 0=256 1=1 5=1 6=131072
+Swish                    Mul_163                  1 1 982 662
+Split                    splitncnn_14             1 2 662 662_splitncnn_0 662_splitncnn_1
+Interp                   Resize_165               1 1 662_splitncnn_1 667 0=1 1=2.000000e+00 2=2.000000e+00
+Concat                   Concat_166               2 1 667 621_splitncnn_0 668
+Split                    splitncnn_15             1 2 668 668_splitncnn_0 668_splitncnn_1
+Convolution              Conv_167                 1 1 668_splitncnn_1 985 0=128 1=1 5=1 6=65536
+Swish                    Mul_169                  1 1 985 672
+Convolution              Conv_170                 1 1 668_splitncnn_0 988 0=128 1=1 5=1 6=65536
+Swish                    Mul_172                  1 1 988 676
+Convolution              Conv_173                 1 1 672 991 0=128 1=1 5=1 6=16384
+Swish                    Mul_175                  1 1 991 680
+Convolution              Conv_176                 1 1 680 994 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_178                  1 1 994 684
+Concat                   Concat_179               2 1 684 676 685
+Convolution              Conv_180                 1 1 685 997 0=256 1=1 5=1 6=65536
+Swish                    Mul_182                  1 1 997 689
+Convolution              Conv_183                 1 1 689 1000 0=128 1=1 5=1 6=32768
+Swish                    Mul_185                  1 1 1000 693
+Split                    splitncnn_16             1 2 693 693_splitncnn_0 693_splitncnn_1
+Interp                   Resize_187               1 1 693_splitncnn_1 698 0=1 1=2.000000e+00 2=2.000000e+00
+Concat                   Concat_188               2 1 698 577_splitncnn_0 699
+Split                    splitncnn_17             1 2 699 699_splitncnn_0 699_splitncnn_1
+Convolution              Conv_189                 1 1 699_splitncnn_1 1003 0=64 1=1 5=1 6=16384
+Swish                    Mul_191                  1 1 1003 703
+Convolution              Conv_192                 1 1 699_splitncnn_0 1006 0=64 1=1 5=1 6=16384
+Swish                    Mul_194                  1 1 1006 707
+Convolution              Conv_195                 1 1 703 1009 0=64 1=1 5=1 6=4096
+Swish                    Mul_197                  1 1 1009 711
+Convolution              Conv_198                 1 1 711 1012 0=64 1=3 4=1 5=1 6=36864
+Swish                    Mul_200                  1 1 1012 715
+Concat                   Concat_201               2 1 715 707 716
+Convolution              Conv_202                 1 1 716 1015 0=128 1=1 5=1 6=16384
+Swish                    Mul_204                  1 1 1015 720
+Split                    splitncnn_18             1 2 720 720_splitncnn_0 720_splitncnn_1
+Convolution              Conv_205                 1 1 720_splitncnn_1 1018 0=128 1=3 3=2 4=1 5=1 6=147456
+Swish                    Mul_207                  1 1 1018 724
+Concat                   Concat_208               2 1 724 693_splitncnn_0 725
+Split                    splitncnn_19             1 2 725 725_splitncnn_0 725_splitncnn_1
+Convolution              Conv_209                 1 1 725_splitncnn_1 1021 0=128 1=1 5=1 6=32768
+Swish                    Mul_211                  1 1 1021 729
+Convolution              Conv_212                 1 1 725_splitncnn_0 1024 0=128 1=1 5=1 6=32768
+Swish                    Mul_214                  1 1 1024 733
+Convolution              Conv_215                 1 1 729 1027 0=128 1=1 5=1 6=16384
+Swish                    Mul_217                  1 1 1027 737
+Convolution              Conv_218                 1 1 737 1030 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_220                  1 1 1030 741
+Concat                   Concat_221               2 1 741 733 742
+Convolution              Conv_222                 1 1 742 1033 0=256 1=1 5=1 6=65536
+Swish                    Mul_224                  1 1 1033 746
+Split                    splitncnn_20             1 2 746 746_splitncnn_0 746_splitncnn_1
+Convolution              Conv_225                 1 1 746_splitncnn_1 1036 0=256 1=3 3=2 4=1 5=1 6=589824
+Swish                    Mul_227                  1 1 1036 750
+Concat                   Concat_228               2 1 750 662_splitncnn_0 751
+Split                    splitncnn_21             1 2 751 751_splitncnn_0 751_splitncnn_1
+Convolution              Conv_229                 1 1 751_splitncnn_1 1039 0=256 1=1 5=1 6=131072
+Swish                    Mul_231                  1 1 1039 755
+Convolution              Conv_232                 1 1 751_splitncnn_0 1042 0=256 1=1 5=1 6=131072
+Swish                    Mul_234                  1 1 1042 759
+Convolution              Conv_235                 1 1 755 1045 0=256 1=1 5=1 6=65536
+Swish                    Mul_237                  1 1 1045 763
+Convolution              Conv_238                 1 1 763 1048 0=256 1=3 4=1 5=1 6=589824
+Swish                    Mul_240                  1 1 1048 767
+Concat                   Concat_241               2 1 767 759 768
+Convolution              Conv_242                 1 1 768 1051 0=512 1=1 5=1 6=262144
+Swish                    Mul_244                  1 1 1051 772
+Convolution              Conv_245                 1 1 720_splitncnn_0 1054 0=128 1=1 5=1 6=16384
+Swish                    Mul_247                  1 1 1054 776
+Split                    splitncnn_22             1 2 776 776_splitncnn_0 776_splitncnn_1
+Convolution              Conv_248                 1 1 776_splitncnn_1 1057 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_250                  1 1 1057 780
+Convolution              Conv_251                 1 1 780 1060 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_253                  1 1 1060 784
+Convolution              Conv_254                 1 1 784 797 0=80 1=1 5=1 6=10240 9=4
+Convolution              Conv_255                 1 1 776_splitncnn_0 1063 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_257                  1 1 1063 789
+Convolution              Conv_258                 1 1 789 1066 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_260                  1 1 1066 793
+Split                    splitncnn_23             1 2 793 793_splitncnn_0 793_splitncnn_1
+Convolution              Conv_261                 1 1 793_splitncnn_1 794 0=4 1=1 5=1 6=512
+Convolution              Conv_262                 1 1 793_splitncnn_0 796 0=1 1=1 5=1 6=128 9=4
+Concat                   Concat_265               3 1 794 796 797 798
+Convolution              Conv_266                 1 1 746_splitncnn_0 1069 0=128 1=1 5=1 6=32768
+Swish                    Mul_268                  1 1 1069 802
+Split                    splitncnn_24             1 2 802 802_splitncnn_0 802_splitncnn_1
+Convolution              Conv_269                 1 1 802_splitncnn_1 1072 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_271                  1 1 1072 806
+Convolution              Conv_272                 1 1 806 1075 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_274                  1 1 1075 810
+Convolution              Conv_275                 1 1 810 823 0=80 1=1 5=1 6=10240 9=4
+Convolution              Conv_276                 1 1 802_splitncnn_0 1078 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_278                  1 1 1078 815
+Convolution              Conv_279                 1 1 815 1081 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_281                  1 1 1081 819
+Split                    splitncnn_25             1 2 819 819_splitncnn_0 819_splitncnn_1
+Convolution              Conv_282                 1 1 819_splitncnn_1 820 0=4 1=1 5=1 6=512
+Convolution              Conv_283                 1 1 819_splitncnn_0 822 0=1 1=1 5=1 6=128 9=4
+Concat                   Concat_286               3 1 820 822 823 824
+Convolution              Conv_287                 1 1 772 1084 0=128 1=1 5=1 6=65536
+Swish                    Mul_289                  1 1 1084 828
+Split                    splitncnn_26             1 2 828 828_splitncnn_0 828_splitncnn_1
+Convolution              Conv_290                 1 1 828_splitncnn_1 1087 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_292                  1 1 1087 832
+Convolution              Conv_293                 1 1 832 1090 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_295                  1 1 1090 836
+Convolution              Conv_296                 1 1 836 849 0=80 1=1 5=1 6=10240 9=4
+Convolution              Conv_297                 1 1 828_splitncnn_0 1093 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_299                  1 1 1093 841
+Convolution              Conv_300                 1 1 841 1096 0=128 1=3 4=1 5=1 6=147456
+Swish                    Mul_302                  1 1 1096 845
+Split                    splitncnn_27             1 2 845 845_splitncnn_0 845_splitncnn_1
+Convolution              Conv_303                 1 1 845_splitncnn_1 846 0=4 1=1 5=1 6=512
+Convolution              Conv_304                 1 1 845_splitncnn_0 848 0=1 1=1 5=1 6=128 9=4
+Concat                   Concat_307               3 1 846 848 849 850
+Reshape                  Reshape_315              1 1 798 858 0=-1 1=85
+Reshape                  Reshape_323              1 1 824 866 0=-1 1=85
+Reshape                  Reshape_331              1 1 850 874 0=-1 1=85
+Concat                   Concat_332               3 1 858 866 874 875 0=1
+Permute                  Transpose_333            1 1 875 output 0=1
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/MainActivity.java b/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/MainActivity.java
new file mode 100644
index 0000000000000000000000000000000000000000..0f57e4f1297e3d4787d9e859bccc386bd3bbab06
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/MainActivity.java
@@ -0,0 +1,247 @@
+// Some code in this file is based on:
+// https://github.com/nihui/ncnn-android-yolov5/blob/master/app/src/main/java/com/tencent/yolov5ncnn/MainActivity.java
+// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
+// Copyright (C) Megvii, Inc. and its affiliates. All rights reserved.
+
+package com.megvii.yoloXncnn;
+
+import android.app.Activity;
+import android.content.Intent;
+import android.graphics.Bitmap;
+import android.graphics.BitmapFactory;
+import android.graphics.Canvas;
+import android.graphics.Color;
+import android.graphics.Paint;
+import android.media.ExifInterface;
+import android.graphics.Matrix;
+import android.net.Uri;
+import android.os.Bundle;
+import android.util.Log;
+import android.view.View;
+import android.widget.Button;
+import android.widget.ImageView;
+
+import java.io.FileNotFoundException;
+import java.io.InputStream;
+import java.io.IOException;
+
+public class MainActivity extends Activity
+{
+    private static final int SELECT_IMAGE = 1;
+
+    private ImageView imageView;
+    private Bitmap bitmap = null;
+    private Bitmap yourSelectedImage = null;
+
+    private YOLOXncnn yoloX = new YOLOXncnn();
+
+    /** Called when the activity is first created. */
+    @Override
+    public void onCreate(Bundle savedInstanceState)
+    {
+        super.onCreate(savedInstanceState);
+        setContentView(R.layout.main);
+
+        boolean ret_init = yoloX.Init(getAssets());
+        if (!ret_init)
+        {
+            Log.e("MainActivity", "yoloXncnn Init failed");
+        }
+
+        imageView = (ImageView) findViewById(R.id.imageView);
+
+        Button buttonImage = (Button) findViewById(R.id.buttonImage);
+        buttonImage.setOnClickListener(new View.OnClickListener() {
+            @Override
+            public void onClick(View arg0) {
+                Intent i = new Intent(Intent.ACTION_PICK);
+                i.setType("image/*");
+                startActivityForResult(i, SELECT_IMAGE);
+            }
+        });
+
+        Button buttonDetect = (Button) findViewById(R.id.buttonDetect);
+        buttonDetect.setOnClickListener(new View.OnClickListener() {
+            @Override
+            public void onClick(View arg0) {
+                if (yourSelectedImage == null)
+                    return;
+                YOLOXncnn.Obj[] objects = yoloX.Detect(yourSelectedImage, false);
+
+                showObjects(objects);
+            }
+        });
+
+        Button buttonDetectGPU = (Button) findViewById(R.id.buttonDetectGPU);
+        buttonDetectGPU.setOnClickListener(new View.OnClickListener() {
+            @Override
+            public void onClick(View arg0) {
+                if (yourSelectedImage == null)
+                    return;
+
+                YOLOXncnn.Obj[] objects = yoloX.Detect(yourSelectedImage, true);
+
+                showObjects(objects);
+            }
+        });
+    }
+
+    private void showObjects(YOLOXncnn.Obj[] objects)
+    {
+        if (objects == null)
+        {
+            imageView.setImageBitmap(bitmap);
+            return;
+        }
+
+        // draw objects on bitmap
+        Bitmap rgba = bitmap.copy(Bitmap.Config.ARGB_8888, true);
+
+        final int[] colors = new int[] {
+            Color.rgb( 54,  67, 244),
+            Color.rgb( 99,  30, 233),
+            Color.rgb(176,  39, 156),
+            Color.rgb(183,  58, 103),
+            Color.rgb(181,  81,  63),
+            Color.rgb(243, 150,  33),
+            Color.rgb(244, 169,   3),
+            Color.rgb(212, 188,   0),
+            Color.rgb(136, 150,   0),
+            Color.rgb( 80, 175,  76),
+            Color.rgb( 74, 195, 139),
+            Color.rgb( 57, 220, 205),
+            Color.rgb( 59, 235, 255),
+            Color.rgb(  7, 193, 255),
+            Color.rgb(  0, 152, 255),
+            Color.rgb( 34,  87, 255),
+            Color.rgb( 72,  85, 121),
+            Color.rgb(158, 158, 158),
+            Color.rgb(139, 125,  96)
+        };
+
+        Canvas canvas = new Canvas(rgba);
+
+        Paint paint = new Paint();
+        paint.setStyle(Paint.Style.STROKE);
+        paint.setStrokeWidth(4);
+
+        Paint textbgpaint = new Paint();
+        textbgpaint.setColor(Color.WHITE);
+        textbgpaint.setStyle(Paint.Style.FILL);
+
+        Paint textpaint = new Paint();
+        textpaint.setColor(Color.BLACK);
+        textpaint.setTextSize(26);
+        textpaint.setTextAlign(Paint.Align.LEFT);
+
+        for (int i = 0; i < objects.length; i++)
+        {
+            paint.setColor(colors[i % 19]);
+
+            canvas.drawRect(objects[i].x, objects[i].y, objects[i].x + objects[i].w, objects[i].y + objects[i].h, paint);
+
+            // draw filled text inside image
+            {
+                String text = objects[i].label + " = " + String.format("%.1f", objects[i].prob * 100) + "%";
+
+                float text_width = textpaint.measureText(text);
+                float text_height = - textpaint.ascent() + textpaint.descent();
+
+                float x = objects[i].x;
+                float y = objects[i].y - text_height;
+                if (y < 0)
+                    y = 0;
+                if (x + text_width > rgba.getWidth())
+                    x = rgba.getWidth() - text_width;
+
+                canvas.drawRect(x, y, x + text_width, y + text_height, textbgpaint);
+
+                canvas.drawText(text, x, y - textpaint.ascent(), textpaint);
+            }
+        }
+
+        imageView.setImageBitmap(rgba);
+    }
+
+    @Override
+    protected void onActivityResult(int requestCode, int resultCode, Intent data)
+    {
+        super.onActivityResult(requestCode, resultCode, data);
+
+        if (resultCode == RESULT_OK && null != data) {
+            Uri selectedImage = data.getData();
+
+            try
+            {
+                if (requestCode == SELECT_IMAGE) {
+                    bitmap = decodeUri(selectedImage);
+
+                    yourSelectedImage = bitmap.copy(Bitmap.Config.ARGB_8888, true);
+
+                    imageView.setImageBitmap(bitmap);
+                }
+            }
+            catch (FileNotFoundException e)
+            {
+                Log.e("MainActivity", "FileNotFoundException");
+                return;
+            }
+        }
+    }
+
+    private Bitmap decodeUri(Uri selectedImage) throws FileNotFoundException
+    {
+        // Decode image size
+        BitmapFactory.Options o = new BitmapFactory.Options();
+        o.inJustDecodeBounds = true;
+        BitmapFactory.decodeStream(getContentResolver().openInputStream(selectedImage), null, o);
+
+        // The new size we want to scale to
+        final int REQUIRED_SIZE = 640;
+
+        // Find the correct scale value. It should be the power of 2.
+        int width_tmp = o.outWidth, height_tmp = o.outHeight;
+        int scale = 1;
+        while (true) {
+            if (width_tmp / 2 < REQUIRED_SIZE || height_tmp / 2 < REQUIRED_SIZE) {
+                break;
+            }
+            width_tmp /= 2;
+            height_tmp /= 2;
+            scale *= 2;
+        }
+
+        // Decode with inSampleSize
+        BitmapFactory.Options o2 = new BitmapFactory.Options();
+        o2.inSampleSize = scale;
+        Bitmap bitmap = BitmapFactory.decodeStream(getContentResolver().openInputStream(selectedImage), null, o2);
+
+        // Rotate according to EXIF
+        int rotate = 0;
+        try
+        {
+            ExifInterface exif = new ExifInterface(getContentResolver().openInputStream(selectedImage));
+            int orientation = exif.getAttributeInt(ExifInterface.TAG_ORIENTATION, ExifInterface.ORIENTATION_NORMAL);
+            switch (orientation) {
+                case ExifInterface.ORIENTATION_ROTATE_270:
+                    rotate = 270;
+                    break;
+                case ExifInterface.ORIENTATION_ROTATE_180:
+                    rotate = 180;
+                    break;
+                case ExifInterface.ORIENTATION_ROTATE_90:
+                    rotate = 90;
+                    break;
+            }
+        }
+        catch (IOException e)
+        {
+            Log.e("MainActivity", "ExifInterface IOException");
+        }
+
+        Matrix matrix = new Matrix();
+        matrix.postRotate(rotate);
+        return Bitmap.createBitmap(bitmap, 0, 0, bitmap.getWidth(), bitmap.getHeight(), matrix, true);
+    }
+
+}
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/YOLOXncnn.java b/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/YOLOXncnn.java
new file mode 100644
index 0000000000000000000000000000000000000000..212e1c2b881b89c69f27211160df0d2c61a098d8
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/YOLOXncnn.java
@@ -0,0 +1,27 @@
+// Copyright (C) Megvii, Inc. and its affiliates. All rights reserved.
+
+package com.megvii.yoloXncnn;
+
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+
+public class YOLOXncnn
+{
+    public native boolean Init(AssetManager mgr);
+
+    public class Obj
+    {
+        public float x;
+        public float y;
+        public float w;
+        public float h;
+        public String label;
+        public float prob;
+    }
+
+    public native Obj[] Detect(Bitmap bitmap, boolean use_gpu);
+
+    static {
+        System.loadLibrary("yoloXncnn");
+    }
+}
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/yoloXncnn.java b/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/yoloXncnn.java
new file mode 100644
index 0000000000000000000000000000000000000000..212e1c2b881b89c69f27211160df0d2c61a098d8
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/java/com/megvii/yoloXncnn/yoloXncnn.java
@@ -0,0 +1,27 @@
+// Copyright (C) Megvii, Inc. and its affiliates. All rights reserved.
+
+package com.megvii.yoloXncnn;
+
+import android.content.res.AssetManager;
+import android.graphics.Bitmap;
+
+public class YOLOXncnn
+{
+    public native boolean Init(AssetManager mgr);
+
+    public class Obj
+    {
+        public float x;
+        public float y;
+        public float w;
+        public float h;
+        public String label;
+        public float prob;
+    }
+
+    public native Obj[] Detect(Bitmap bitmap, boolean use_gpu);
+
+    static {
+        System.loadLibrary("yoloXncnn");
+    }
+}
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/jni/CMakeLists.txt b/multimodal/YOLOX/demo/ncnn/android/app/src/main/jni/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d4b8cd476bc66ce7f4a381dd4299cf391ad67260
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/jni/CMakeLists.txt
@@ -0,0 +1,14 @@
+project(yoloXncnn)
+
+cmake_minimum_required(VERSION 3.4.1)
+
+set(ncnn_DIR ${CMAKE_SOURCE_DIR}/ncnn-20210525-android-vulkan/${ANDROID_ABI}/lib/cmake/ncnn)
+find_package(ncnn REQUIRED)
+
+add_library(yoloXncnn SHARED yoloXncnn_jni.cpp)
+
+target_link_libraries(yoloXncnn
+    ncnn
+
+    jnigraphics
+)
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/jni/yoloXncnn_jni.cpp b/multimodal/YOLOX/demo/ncnn/android/app/src/main/jni/yoloXncnn_jni.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c27867d2be73cd51a02033f6e7a50b7721954db8
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/jni/yoloXncnn_jni.cpp
@@ -0,0 +1,474 @@
+// Some code in this file is based on:
+// https://github.com/nihui/ncnn-android-yolov5/blob/master/app/src/main/jni/yolov5ncnn_jni.cpp
+// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
+// Copyright (C) Megvii, Inc. and its affiliates. All rights reserved.
+
+#include <android/asset_manager_jni.h>
+#include <android/bitmap.h>
+#include <android/log.h>
+
+#include <jni.h>
+
+#include <string>
+#include <vector>
+
+// ncnn
+#include "layer.h"
+#include "net.h"
+#include "benchmark.h"
+
+static ncnn::UnlockedPoolAllocator g_blob_pool_allocator;
+static ncnn::PoolAllocator g_workspace_pool_allocator;
+
+static ncnn::Net yoloX;
+
+class YoloV5Focus : public ncnn::Layer
+{
+public:
+    YoloV5Focus()
+    {
+        one_blob_only = true;
+    }
+
+    virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const
+    {
+        int w = bottom_blob.w;
+        int h = bottom_blob.h;
+        int channels = bottom_blob.c;
+
+        int outw = w / 2;
+        int outh = h / 2;
+        int outc = channels * 4;
+
+        top_blob.create(outw, outh, outc, 4u, 1, opt.blob_allocator);
+        if (top_blob.empty())
+            return -100;
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int p = 0; p < outc; p++)
+        {
+            const float* ptr = bottom_blob.channel(p % channels).row((p / channels) % 2) + ((p / channels) / 2);
+            float* outptr = top_blob.channel(p);
+
+            for (int i = 0; i < outh; i++)
+            {
+                for (int j = 0; j < outw; j++)
+                {
+                    *outptr = *ptr;
+
+                    outptr += 1;
+                    ptr += 2;
+                }
+
+                ptr += w;
+            }
+        }
+
+        return 0;
+    }
+};
+
+DEFINE_LAYER_CREATOR(YoloV5Focus)
+
+struct Object
+{
+    float x;
+    float y;
+    float w;
+    float h;
+    int label;
+    float prob;
+};
+
+struct GridAndStride
+{
+    int grid0;
+    int grid1;
+    int stride;
+};
+
+static inline float intersection_area(const Object& a, const Object& b)
+{
+    if (a.x > b.x + b.w || a.x + a.w < b.x || a.y > b.y + b.h || a.y + a.h < b.y)
+    {
+        // no intersection
+        return 0.f;
+    }
+
+    float inter_width = std::min(a.x + a.w, b.x + b.w) - std::max(a.x, b.x);
+    float inter_height = std::min(a.y + a.h, b.y + b.h) - std::max(a.y, b.y);
+
+    return inter_width * inter_height;
+}
+
+static void qsort_descent_inplace(std::vector<Object>& faceobjects, int left, int right)
+{
+    int i = left;
+    int j = right;
+    float p = faceobjects[(left + right) / 2].prob;
+
+    while (i <= j)
+    {
+        while (faceobjects[i].prob > p)
+            i++;
+
+        while (faceobjects[j].prob < p)
+            j--;
+
+        if (i <= j)
+        {
+            // swap
+            std::swap(faceobjects[i], faceobjects[j]);
+
+            i++;
+            j--;
+        }
+    }
+
+    #pragma omp parallel sections
+    {
+        #pragma omp section
+        {
+            if (left < j) qsort_descent_inplace(faceobjects, left, j);
+        }
+        #pragma omp section
+        {
+            if (i < right) qsort_descent_inplace(faceobjects, i, right);
+        }
+    }
+}
+
+static void qsort_descent_inplace(std::vector<Object>& faceobjects)
+{
+    if (faceobjects.empty())
+        return;
+
+    qsort_descent_inplace(faceobjects, 0, faceobjects.size() - 1);
+}
+
+static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold)
+{
+    picked.clear();
+
+    const int n = faceobjects.size();
+
+    std::vector<float> areas(n);
+    for (int i = 0; i < n; i++)
+    {
+        areas[i] = faceobjects[i].w * faceobjects[i].h;
+    }
+
+    for (int i = 0; i < n; i++)
+    {
+        const Object& a = faceobjects[i];
+
+        int keep = 1;
+        for (int j = 0; j < (int)picked.size(); j++)
+        {
+            const Object& b = faceobjects[picked[j]];
+
+            // intersection over union
+            float inter_area = intersection_area(a, b);
+            float union_area = areas[i] + areas[picked[j]] - inter_area;
+            // float IoU = inter_area / union_area
+            if (inter_area / union_area > nms_threshold)
+                keep = 0;
+        }
+
+        if (keep)
+            picked.push_back(i);
+    }
+}
+
+static void generate_grids_and_stride(const int target_size, std::vector<int>& strides, std::vector<GridAndStride>& grid_strides)
+{
+    for (auto stride : strides)
+    {
+        int num_grid = target_size / stride;
+        for (int g1 = 0; g1 < num_grid; g1++)
+        {
+            for (int g0 = 0; g0 < num_grid; g0++)
+            {
+                grid_strides.push_back((GridAndStride){g0, g1, stride});
+            }
+        }
+    }
+}
+
+static void generate_yolox_proposals(std::vector<GridAndStride> grid_strides, const ncnn::Mat& feat_blob, float prob_threshold, std::vector<Object>& objects)
+{
+    const int num_grid = feat_blob.h;
+    fprintf(stderr, "output height: %d, width: %d, channels: %d, dims:%d\n", feat_blob.h, feat_blob.w, feat_blob.c, feat_blob.dims);
+
+    const int num_class = feat_blob.w - 5;
+
+    const int num_anchors = grid_strides.size();
+
+    const float* feat_ptr = feat_blob.channel(0);
+    for (int anchor_idx = 0; anchor_idx < num_anchors; anchor_idx++)
+    {
+        const int grid0 = grid_strides[anchor_idx].grid0;
+        const int grid1 = grid_strides[anchor_idx].grid1;
+        const int stride = grid_strides[anchor_idx].stride;
+
+        // yolox/models/yolo_head.py decode logic
+        //  outputs[..., :2] = (outputs[..., :2] + grids) * strides
+        //  outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
+        float x_center = (feat_ptr[0] + grid0) * stride;
+        float y_center = (feat_ptr[1] + grid1) * stride;
+        float w = exp(feat_ptr[2]) * stride;
+        float h = exp(feat_ptr[3]) * stride;
+        float x0 = x_center - w * 0.5f;
+        float y0 = y_center - h * 0.5f;
+
+        float box_objectness = feat_ptr[4];
+        for (int class_idx = 0; class_idx < num_class; class_idx++)
+        {
+            float box_cls_score = feat_ptr[5 + class_idx];
+            float box_prob = box_objectness * box_cls_score;
+            if (box_prob > prob_threshold)
+            {
+                Object obj;
+                obj.x = x0;
+                obj.y = y0;
+                obj.w = w;
+                obj.h = h;
+                obj.label = class_idx;
+                obj.prob = box_prob;
+
+                objects.push_back(obj);
+            }
+
+        } // class loop
+        feat_ptr += feat_blob.w;
+
+    } // point anchor loop
+}
+
+
+extern "C" {
+
+// FIXME DeleteGlobalRef is missing for objCls
+static jclass objCls = NULL;
+static jmethodID constructortorId;
+static jfieldID xId;
+static jfieldID yId;
+static jfieldID wId;
+static jfieldID hId;
+static jfieldID labelId;
+static jfieldID probId;
+
+JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void* reserved)
+{
+    __android_log_print(ANDROID_LOG_DEBUG, "YOLOXncnn", "JNI_OnLoad");
+
+    ncnn::create_gpu_instance();
+
+    return JNI_VERSION_1_4;
+}
+
+JNIEXPORT void JNI_OnUnload(JavaVM* vm, void* reserved)
+{
+    __android_log_print(ANDROID_LOG_DEBUG, "YOLOXncnn", "JNI_OnUnload");
+
+    ncnn::destroy_gpu_instance();
+}
+
+// public native boolean Init(AssetManager mgr);
+JNIEXPORT jboolean JNICALL Java_com_megvii_yoloXncnn_YOLOXncnn_Init(JNIEnv* env, jobject thiz, jobject assetManager)
+{
+    ncnn::Option opt;
+    opt.lightmode = true;
+    opt.num_threads = 4;
+    opt.blob_allocator = &g_blob_pool_allocator;
+    opt.workspace_allocator = &g_workspace_pool_allocator;
+    opt.use_packing_layout = true;
+
+    // use vulkan compute
+    if (ncnn::get_gpu_count() != 0)
+        opt.use_vulkan_compute = true;
+
+    AAssetManager* mgr = AAssetManager_fromJava(env, assetManager);
+
+    yoloX.opt = opt;
+
+    yoloX.register_custom_layer("YoloV5Focus", YoloV5Focus_layer_creator);
+
+    // init param
+    {
+        int ret = yoloX.load_param(mgr, "yolox.param");
+        if (ret != 0)
+        {
+            __android_log_print(ANDROID_LOG_DEBUG, "YOLOXncnn", "load_param failed");
+            return JNI_FALSE;
+        }
+    }
+
+    // init bin
+    {
+        int ret = yoloX.load_model(mgr, "yolox.bin");
+        if (ret != 0)
+        {
+            __android_log_print(ANDROID_LOG_DEBUG, "YOLOXncnn", "load_model failed");
+            return JNI_FALSE;
+        }
+    }
+
+    // init jni glue
+    jclass localObjCls = env->FindClass("com/megvii/yoloXncnn/YOLOXncnn$Obj");
+    objCls = reinterpret_cast<jclass>(env->NewGlobalRef(localObjCls));
+
+    constructortorId = env->GetMethodID(objCls, "<init>", "(Lcom/megvii/yoloXncnn/YOLOXncnn;)V");
+
+    xId = env->GetFieldID(objCls, "x", "F");
+    yId = env->GetFieldID(objCls, "y", "F");
+    wId = env->GetFieldID(objCls, "w", "F");
+    hId = env->GetFieldID(objCls, "h", "F");
+    labelId = env->GetFieldID(objCls, "label", "Ljava/lang/String;");
+    probId = env->GetFieldID(objCls, "prob", "F");
+
+    return JNI_TRUE;
+}
+
+// public native Obj[] Detect(Bitmap bitmap, boolean use_gpu);
+JNIEXPORT jobjectArray JNICALL Java_com_megvii_yoloXncnn_YOLOXncnn_Detect(JNIEnv* env, jobject thiz, jobject bitmap, jboolean use_gpu)
+{
+    if (use_gpu == JNI_TRUE && ncnn::get_gpu_count() == 0)
+    {
+        return NULL;
+        //return env->NewStringUTF("no vulkan capable gpu");
+    }
+
+    double start_time = ncnn::get_current_time();
+
+    AndroidBitmapInfo info;
+    AndroidBitmap_getInfo(env, bitmap, &info);
+    const int width = info.width;
+    const int height = info.height;
+    if (info.format != ANDROID_BITMAP_FORMAT_RGBA_8888)
+        return NULL;
+
+    // parameters which might change for different model
+    const int target_size = 640;
+    const float prob_threshold = 0.3f;
+    const float nms_threshold = 0.65f;
+    std::vector<int> strides = {8, 16, 32}; // might have stride=64
+
+    int w = width;
+    int h = height;
+    float scale = 1.f;
+    if (w > h)
+    {
+        scale = (float)target_size / w;
+        w = target_size;
+        h = h * scale;
+    }
+    else
+    {
+        scale = (float)target_size / h;
+        h = target_size;
+        w = w * scale;
+    }
+
+    ncnn::Mat in = ncnn::Mat::from_android_bitmap_resize(env, bitmap, ncnn::Mat::PIXEL_RGB2BGR, w, h);
+
+    // pad to target_size rectangle
+    int wpad = target_size - w;
+    int hpad = target_size - h;
+    ncnn::Mat in_pad;
+    // different from yolov5, yolox only pad on bottom and right side,
+    // which means users don't need to extra padding info to decode boxes coordinate.
+    ncnn::copy_make_border(in, in_pad, 0, hpad, 0, wpad, ncnn::BORDER_CONSTANT, 114.f);
+
+    // yolox
+    std::vector<Object> objects;
+    {
+
+        ncnn::Extractor ex = yoloX.create_extractor();
+
+        ex.set_vulkan_compute(use_gpu);
+
+        ex.input("images", in_pad);
+
+        std::vector<Object> proposals;
+
+        // yolox decode and generate proposal logic
+        {
+            ncnn::Mat out;
+            ex.extract("output", out);
+
+            std::vector<GridAndStride> grid_strides;
+            generate_grids_and_stride(target_size, strides, grid_strides);
+            generate_yolox_proposals(grid_strides, out, prob_threshold, proposals);
+
+        }
+
+        // sort all proposals by score from highest to lowest
+        qsort_descent_inplace(proposals);
+
+        // apply nms with nms_threshold
+        std::vector<int> picked;
+        nms_sorted_bboxes(proposals, picked, nms_threshold);
+
+        int count = picked.size();
+
+        objects.resize(count);
+        for (int i = 0; i < count; i++)
+        {
+            objects[i] = proposals[picked[i]];
+
+            // adjust offset to original unpadded
+            float x0 = (objects[i].x) / scale;
+            float y0 = (objects[i].y) / scale;
+            float x1 = (objects[i].x + objects[i].w) / scale;
+            float y1 = (objects[i].y + objects[i].h) / scale;
+
+            // clip
+            x0 = std::max(std::min(x0, (float)(width - 1)), 0.f);
+            y0 = std::max(std::min(y0, (float)(height - 1)), 0.f);
+            x1 = std::max(std::min(x1, (float)(width - 1)), 0.f);
+            y1 = std::max(std::min(y1, (float)(height - 1)), 0.f);
+
+            objects[i].x = x0;
+            objects[i].y = y0;
+            objects[i].w = x1 - x0;
+            objects[i].h = y1 - y0;
+        }
+    }
+
+    // objects to Obj[]
+    static const char* class_names[] = {
+        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
+        "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
+        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
+        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
+        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
+        "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
+        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
+        "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
+        "hair drier", "toothbrush"
+    };
+
+    jobjectArray jObjArray = env->NewObjectArray(objects.size(), objCls, NULL);
+
+    for (size_t i=0; i<objects.size(); i++)
+    {
+        jobject jObj = env->NewObject(objCls, constructortorId, thiz);
+
+        env->SetFloatField(jObj, xId, objects[i].x);
+        env->SetFloatField(jObj, yId, objects[i].y);
+        env->SetFloatField(jObj, wId, objects[i].w);
+        env->SetFloatField(jObj, hId, objects[i].h);
+        env->SetObjectField(jObj, labelId, env->NewStringUTF(class_names[objects[i].label]));
+        env->SetFloatField(jObj, probId, objects[i].prob);
+
+        env->SetObjectArrayElement(jObjArray, i, jObj);
+    }
+
+    double elasped = ncnn::get_current_time() - start_time;
+    __android_log_print(ANDROID_LOG_DEBUG, "YOLOXncnn", "%.2fms   detect", elasped);
+
+    return jObjArray;
+}
+
+}
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/res/layout/main.xml b/multimodal/YOLOX/demo/ncnn/android/app/src/main/res/layout/main.xml
new file mode 100644
index 0000000000000000000000000000000000000000..9440a1fdf222f7be484cb6511c200ea1c0eaae9b
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/res/layout/main.xml
@@ -0,0 +1,35 @@
+<?xml version="1.0" encoding="utf-8"?>
+<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
+    android:orientation="vertical"
+    android:layout_width="fill_parent"
+    android:layout_height="fill_parent">
+
+    <LinearLayout
+        android:orientation="horizontal"
+        android:layout_width="fill_parent"
+        android:layout_height="wrap_content">
+
+    <Button
+        android:id="@+id/buttonImage"
+        android:layout_width="wrap_content"
+        android:layout_height="wrap_content"
+        android:text="image" />
+    <Button
+        android:id="@+id/buttonDetect"
+        android:layout_width="wrap_content"
+        android:layout_height="wrap_content"
+        android:text="infer-cpu" />
+    <Button
+        android:id="@+id/buttonDetectGPU"
+        android:layout_width="wrap_content"
+        android:layout_height="wrap_content"
+        android:text="infer-gpu" />
+    </LinearLayout>
+
+    <ImageView
+        android:id="@+id/imageView"
+        android:layout_width="fill_parent"
+        android:layout_height="fill_parent"
+        android:layout_weight="1" />
+
+</LinearLayout>
diff --git a/multimodal/YOLOX/demo/ncnn/android/app/src/main/res/values/strings.xml b/multimodal/YOLOX/demo/ncnn/android/app/src/main/res/values/strings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..853a8a084f8c8dd532131df0de8c39eba7b11a9d
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/app/src/main/res/values/strings.xml
@@ -0,0 +1,4 @@
+<?xml version="1.0" encoding="utf-8"?>
+<resources>
+    <string name="app_name">yoloXncnn</string>
+</resources>
diff --git a/multimodal/YOLOX/demo/ncnn/android/build.gradle b/multimodal/YOLOX/demo/ncnn/android/build.gradle
new file mode 100644
index 0000000000000000000000000000000000000000..88031d6a3a28f01e5f46d46fd6ce82453c834061
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/build.gradle
@@ -0,0 +1,17 @@
+// Top-level build file where you can add configuration options common to all sub-projects/modules.
+buildscript {
+    repositories {
+        jcenter()
+        google()
+    }
+    dependencies {
+        classpath 'com.android.tools.build:gradle:3.5.0'
+    }
+}
+
+allprojects {
+    repositories {
+        jcenter()
+        google()
+    }
+}
diff --git a/multimodal/YOLOX/demo/ncnn/android/gradle/wrapper/gradle-wrapper.jar b/multimodal/YOLOX/demo/ncnn/android/gradle/wrapper/gradle-wrapper.jar
new file mode 100644
index 0000000000000000000000000000000000000000..f6b961fd5a86aa5fbfe90f707c3138408be7c718
Binary files /dev/null and b/multimodal/YOLOX/demo/ncnn/android/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/multimodal/YOLOX/demo/ncnn/android/gradle/wrapper/gradle-wrapper.properties b/multimodal/YOLOX/demo/ncnn/android/gradle/wrapper/gradle-wrapper.properties
new file mode 100644
index 0000000000000000000000000000000000000000..866fd2b81c5e9df5a9e4ae6c6cbd27a979c2484e
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/gradle/wrapper/gradle-wrapper.properties
@@ -0,0 +1,6 @@
+#Sun Aug 25 10:34:48 CST 2019
+distributionBase=GRADLE_USER_HOME
+distributionPath=wrapper/dists
+zipStoreBase=GRADLE_USER_HOME
+zipStorePath=wrapper/dists
+distributionUrl=https\://services.gradle.org/distributions/gradle-5.4.1-all.zip
diff --git a/multimodal/YOLOX/demo/ncnn/android/gradlew b/multimodal/YOLOX/demo/ncnn/android/gradlew
new file mode 100755
index 0000000000000000000000000000000000000000..cccdd3d517fc5249beaefa600691cf150f2fa3e6
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/gradlew
@@ -0,0 +1,172 @@
+#!/usr/bin/env sh
+
+##############################################################################
+##
+##  Gradle start up script for UN*X
+##
+##############################################################################
+
+# Attempt to set APP_HOME
+# Resolve links: $0 may be a link
+PRG="$0"
+# Need this for relative symlinks.
+while [ -h "$PRG" ] ; do
+    ls=`ls -ld "$PRG"`
+    link=`expr "$ls" : '.*-> \(.*\)$'`
+    if expr "$link" : '/.*' > /dev/null; then
+        PRG="$link"
+    else
+        PRG=`dirname "$PRG"`"/$link"
+    fi
+done
+SAVED="`pwd`"
+cd "`dirname \"$PRG\"`/" >/dev/null
+APP_HOME="`pwd -P`"
+cd "$SAVED" >/dev/null
+
+APP_NAME="Gradle"
+APP_BASE_NAME=`basename "$0"`
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS=""
+
+# Use the maximum available, or set MAX_FD != -1 to use that value.
+MAX_FD="maximum"
+
+warn () {
+    echo "$*"
+}
+
+die () {
+    echo
+    echo "$*"
+    echo
+    exit 1
+}
+
+# OS specific support (must be 'true' or 'false').
+cygwin=false
+msys=false
+darwin=false
+nonstop=false
+case "`uname`" in
+  CYGWIN* )
+    cygwin=true
+    ;;
+  Darwin* )
+    darwin=true
+    ;;
+  MINGW* )
+    msys=true
+    ;;
+  NONSTOP* )
+    nonstop=true
+    ;;
+esac
+
+CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
+
+# Determine the Java command to use to start the JVM.
+if [ -n "$JAVA_HOME" ] ; then
+    if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
+        # IBM's JDK on AIX uses strange locations for the executables
+        JAVACMD="$JAVA_HOME/jre/sh/java"
+    else
+        JAVACMD="$JAVA_HOME/bin/java"
+    fi
+    if [ ! -x "$JAVACMD" ] ; then
+        die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+    fi
+else
+    JAVACMD="java"
+    which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+
+Please set the JAVA_HOME variable in your environment to match the
+location of your Java installation."
+fi
+
+# Increase the maximum file descriptors if we can.
+if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
+    MAX_FD_LIMIT=`ulimit -H -n`
+    if [ $? -eq 0 ] ; then
+        if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
+            MAX_FD="$MAX_FD_LIMIT"
+        fi
+        ulimit -n $MAX_FD
+        if [ $? -ne 0 ] ; then
+            warn "Could not set maximum file descriptor limit: $MAX_FD"
+        fi
+    else
+        warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
+    fi
+fi
+
+# For Darwin, add options to specify how the application appears in the dock
+if $darwin; then
+    GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
+fi
+
+# For Cygwin, switch paths to Windows format before running java
+if $cygwin ; then
+    APP_HOME=`cygpath --path --mixed "$APP_HOME"`
+    CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
+    JAVACMD=`cygpath --unix "$JAVACMD"`
+
+    # We build the pattern for arguments to be converted via cygpath
+    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
+    SEP=""
+    for dir in $ROOTDIRSRAW ; do
+        ROOTDIRS="$ROOTDIRS$SEP$dir"
+        SEP="|"
+    done
+    OURCYGPATTERN="(^($ROOTDIRS))"
+    # Add a user-defined pattern to the cygpath arguments
+    if [ "$GRADLE_CYGPATTERN" != "" ] ; then
+        OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
+    fi
+    # Now convert the arguments - kludge to limit ourselves to /bin/sh
+    i=0
+    for arg in "$@" ; do
+        CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
+        CHECK2=`echo "$arg"|egrep -c "^-"`                                 ### Determine if an option
+
+        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then                    ### Added a condition
+            eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
+        else
+            eval `echo args$i`="\"$arg\""
+        fi
+        i=$((i+1))
+    done
+    case $i in
+        (0) set -- ;;
+        (1) set -- "$args0" ;;
+        (2) set -- "$args0" "$args1" ;;
+        (3) set -- "$args0" "$args1" "$args2" ;;
+        (4) set -- "$args0" "$args1" "$args2" "$args3" ;;
+        (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
+        (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
+        (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
+        (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
+        (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
+    esac
+fi
+
+# Escape application args
+save () {
+    for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
+    echo " "
+}
+APP_ARGS=$(save "$@")
+
+# Collect all arguments for the java command, following the shell quoting and substitution rules
+eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
+
+# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong
+if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then
+  cd "$(dirname "$0")"
+fi
+
+exec "$JAVACMD" "$@"
diff --git a/multimodal/YOLOX/demo/ncnn/android/gradlew.bat b/multimodal/YOLOX/demo/ncnn/android/gradlew.bat
new file mode 100644
index 0000000000000000000000000000000000000000..f9553162f122c71b34635112e717c3e733b5b212
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/gradlew.bat
@@ -0,0 +1,84 @@
+@if "%DEBUG%" == "" @echo off
+@rem ##########################################################################
+@rem
+@rem  Gradle startup script for Windows
+@rem
+@rem ##########################################################################
+
+@rem Set local scope for the variables with windows NT shell
+if "%OS%"=="Windows_NT" setlocal
+
+set DIRNAME=%~dp0
+if "%DIRNAME%" == "" set DIRNAME=.
+set APP_BASE_NAME=%~n0
+set APP_HOME=%DIRNAME%
+
+@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+set DEFAULT_JVM_OPTS=
+
+@rem Find java.exe
+if defined JAVA_HOME goto findJavaFromJavaHome
+
+set JAVA_EXE=java.exe
+%JAVA_EXE% -version >NUL 2>&1
+if "%ERRORLEVEL%" == "0" goto init
+
+echo.
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:findJavaFromJavaHome
+set JAVA_HOME=%JAVA_HOME:"=%
+set JAVA_EXE=%JAVA_HOME%/bin/java.exe
+
+if exist "%JAVA_EXE%" goto init
+
+echo.
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
+echo.
+echo Please set the JAVA_HOME variable in your environment to match the
+echo location of your Java installation.
+
+goto fail
+
+:init
+@rem Get command-line arguments, handling Windows variants
+
+if not "%OS%" == "Windows_NT" goto win9xME_args
+
+:win9xME_args
+@rem Slurp the command line arguments.
+set CMD_LINE_ARGS=
+set _SKIP=2
+
+:win9xME_args_slurp
+if "x%~1" == "x" goto execute
+
+set CMD_LINE_ARGS=%*
+
+:execute
+@rem Setup the command line
+
+set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
+
+@rem Execute Gradle
+"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS%
+
+:end
+@rem End local scope for the variables with windows NT shell
+if "%ERRORLEVEL%"=="0" goto mainEnd
+
+:fail
+rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
+rem the _cmd.exe /c_ return code!
+if  not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
+exit /b 1
+
+:mainEnd
+if "%OS%"=="Windows_NT" endlocal
+
+:omega
diff --git a/multimodal/YOLOX/demo/ncnn/android/settings.gradle b/multimodal/YOLOX/demo/ncnn/android/settings.gradle
new file mode 100644
index 0000000000000000000000000000000000000000..e7b4def49cb53d9aa04228dd3edb14c9e635e003
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/android/settings.gradle
@@ -0,0 +1 @@
+include ':app'
diff --git a/multimodal/YOLOX/demo/ncnn/cpp/README.md b/multimodal/YOLOX/demo/ncnn/cpp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c3fe79196b57e722d075d0855af469c082637663
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/cpp/README.md
@@ -0,0 +1,104 @@
+# YOLOX-CPP-ncnn
+
+Cpp file compile of YOLOX object detection base on [ncnn](https://github.com/Tencent/ncnn).  
+
+## Tutorial
+
+### Step1
+Clone [ncnn](https://github.com/Tencent/ncnn) first, then please following [build tutorial of ncnn](https://github.com/Tencent/ncnn/wiki/how-to-build) to build on your own device.
+
+### Step2
+First, we try the original onnx2ncnn solution by using provided tools to generate onnx file.
+For example, if you want to generate onnx file of yolox-s, please run the following command:
+```shell
+cd <path of yolox>
+python3 tools/export_onnx.py -n yolox-s
+```
+Then a yolox.onnx file is generated.
+
+### Step3
+Generate ncnn param and bin file.
+```shell
+cd <path of ncnn>
+cd build/tools/ncnn
+./onnx2ncnn yolox.onnx model.param model.bin
+```
+
+Since Focus module is not supported in ncnn. You will see warnings like:
+```shell
+Unsupported slice step!
+```
+However, don't worry on this as a C++ version of Focus layer is already implemented in yolox.cpp.
+
+### Step4
+Open **model.param**, and modify it. For more information on the ncnn param and model file structure, please take a look at this [wiki](https://github.com/Tencent/ncnn/wiki/param-and-model-file-structure).
+
+Before (just an example):
+```
+295 328
+Input            images                   0 1 images
+Split            splitncnn_input0         1 4 images images_splitncnn_0 images_splitncnn_1 images_splitncnn_2 images_splitncnn_3
+Crop             Slice_4                  1 1 images_splitncnn_3 647 -23309=1,0 -23310=1,2147483647 -23311=1,1
+Crop             Slice_9                  1 1 647 652 -23309=1,0 -23310=1,2147483647 -23311=1,2
+Crop             Slice_14                 1 1 images_splitncnn_2 657 -23309=1,0 -23310=1,2147483647 -23311=1,1
+Crop             Slice_19                 1 1 657 662 -23309=1,1 -23310=1,2147483647 -23311=1,2
+Crop             Slice_24                 1 1 images_splitncnn_1 667 -23309=1,1 -23310=1,2147483647 -23311=1,1
+Crop             Slice_29                 1 1 667 672 -23309=1,0 -23310=1,2147483647 -23311=1,2
+Crop             Slice_34                 1 1 images_splitncnn_0 677 -23309=1,1 -23310=1,2147483647 -23311=1,1
+Crop             Slice_39                 1 1 677 682 -23309=1,1 -23310=1,2147483647 -23311=1,2
+Concat           Concat_40                4 1 652 672 662 682 683 0=0
+...
+```
+* Change first number for 295 to 295 - 9 = 286 (since we will remove 10 layers and add 1 layers, total layers number should minus 9). 
+* Then remove 10 lines of code from Split to Concat, but remember the last but 2nd number: 683.
+* Add YoloV5Focus layer After Input (using previous number 683):
+```
+YoloV5Focus      focus                    1 1 images 683
+```
+After(just an example):
+```
+286 328
+Input            images                   0 1 images
+YoloV5Focus      focus                    1 1 images 683
+...
+```
+
+### Step5
+Use ncnn_optimize to generate new param and bin:
+```shell
+# suppose you are still under ncnn/build/tools/ncnn dir.
+../ncnnoptimize model.param model.bin yolox.param yolox.bin 65536
+```
+
+### Step6
+Copy or Move yolox.cpp file into ncnn/examples, modify the CMakeList.txt to add our implementation, then build.
+
+### Step7
+Inference image with executable file yolox, enjoy the detect result:
+```shell
+./yolox demo.jpg
+```
+
+### Bounus Solution:
+As ncnn has released another model conversion tool called [pnnx](https://zhuanlan.zhihu.com/p/427620428) which directly finishs the pytorch2ncnn process via torchscript, we can also try on this.
+
+```shell
+# take yolox-s as an example
+python3 tools/export_torchscript.py -n yolox-s -c /path/to/your_checkpoint_files
+```
+Then a `yolox.torchscript.pt` will be generated. Copy this file to your pnnx build directory (pnnx also provides pre-built packages [here](https://github.com/pnnx/pnnx/releases/tag/20220720)).
+
+```shell
+# suppose you put the yolox.torchscript.pt in a seperate folder
+./pnnx yolox/yolox.torchscript.pt inputshape=[1,3,640,640]
+# for zsh users, please use inputshape='[1,3,640,640]'
+```
+Still, as ncnn does not support `slice` op as we mentioned in [Step3](https://github.com/Megvii-BaseDetection/YOLOX/tree/main/demo/ncnn/cpp#step3). You will still see the warnings during this process.
+
+Then multiple pnnx related files will be genreated in your yolox folder. Use `yolox.torchscript.ncnn.param` and `yolox.torchscript.ncnn.bin` as your converted model. 
+
+Then we can follow back to our [Step4](https://github.com/Megvii-BaseDetection/YOLOX/tree/main/demo/ncnn/cpp#step4) for the rest of our implementation.
+
+## Acknowledgement
+
+* [ncnn](https://github.com/Tencent/ncnn)
diff --git a/multimodal/YOLOX/demo/ncnn/cpp/yolox.cpp b/multimodal/YOLOX/demo/ncnn/cpp/yolox.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0f1a08083cfc6c649e62cba5b2ecf75232af3865
--- /dev/null
+++ b/multimodal/YOLOX/demo/ncnn/cpp/yolox.cpp
@@ -0,0 +1,416 @@
+// This file is wirtten base on the following file:
+// https://github.com/Tencent/ncnn/blob/master/examples/yolov5.cpp
+// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+// ------------------------------------------------------------------------------
+// Copyright (C) 2020-2021, Megvii Inc. All rights reserved.
+
+#include "layer.h"
+#include "net.h"
+
+#if defined(USE_NCNN_SIMPLEOCV)
+#include "simpleocv.h"
+#else
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
+#endif
+#include <float.h>
+#include <stdio.h>
+#include <vector>
+
+#define YOLOX_NMS_THRESH  0.45 // nms threshold
+#define YOLOX_CONF_THRESH 0.25 // threshold of bounding box prob
+#define YOLOX_TARGET_SIZE 640  // target image size after resize, might use 416 for small model
+
+// YOLOX use the same focus in yolov5
+class YoloV5Focus : public ncnn::Layer
+{
+public:
+    YoloV5Focus()
+    {
+        one_blob_only = true;
+    }
+
+    virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const
+    {
+        int w = bottom_blob.w;
+        int h = bottom_blob.h;
+        int channels = bottom_blob.c;
+
+        int outw = w / 2;
+        int outh = h / 2;
+        int outc = channels * 4;
+
+        top_blob.create(outw, outh, outc, 4u, 1, opt.blob_allocator);
+        if (top_blob.empty())
+            return -100;
+
+        #pragma omp parallel for num_threads(opt.num_threads)
+        for (int p = 0; p < outc; p++)
+        {
+            const float* ptr = bottom_blob.channel(p % channels).row((p / channels) % 2) + ((p / channels) / 2);
+            float* outptr = top_blob.channel(p);
+
+            for (int i = 0; i < outh; i++)
+            {
+                for (int j = 0; j < outw; j++)
+                {
+                    *outptr = *ptr;
+
+                    outptr += 1;
+                    ptr += 2;
+                }
+
+                ptr += w;
+            }
+        }
+
+        return 0;
+    }
+};
+
+DEFINE_LAYER_CREATOR(YoloV5Focus)
+
+struct Object
+{
+    cv::Rect_<float> rect;
+    int label;
+    float prob;
+};
+
+struct GridAndStride
+{
+    int grid0;
+    int grid1;
+    int stride;
+};
+
+static inline float intersection_area(const Object& a, const Object& b)
+{
+    cv::Rect_<float> inter = a.rect & b.rect;
+    return inter.area();
+}
+
+static void qsort_descent_inplace(std::vector<Object>& faceobjects, int left, int right)
+{
+    int i = left;
+    int j = right;
+    float p = faceobjects[(left + right) / 2].prob;
+
+    while (i <= j)
+    {
+        while (faceobjects[i].prob > p)
+            i++;
+
+        while (faceobjects[j].prob < p)
+            j--;
+
+        if (i <= j)
+        {
+            // swap
+            std::swap(faceobjects[i], faceobjects[j]);
+
+            i++;
+            j--;
+        }
+    }
+
+    #pragma omp parallel sections
+    {
+        #pragma omp section
+        {
+            if (left < j) qsort_descent_inplace(faceobjects, left, j);
+        }
+        #pragma omp section
+        {
+            if (i < right) qsort_descent_inplace(faceobjects, i, right);
+        }
+    }
+}
+
+static void qsort_descent_inplace(std::vector<Object>& objects)
+{
+    if (objects.empty())
+        return;
+
+    qsort_descent_inplace(objects, 0, objects.size() - 1);
+}
+
+static void nms_sorted_bboxes(const std::vector<Object>& faceobjects, std::vector<int>& picked, float nms_threshold)
+{
+    picked.clear();
+
+    const int n = faceobjects.size();
+
+    std::vector<float> areas(n);
+    for (int i = 0; i < n; i++)
+    {
+        areas[i] = faceobjects[i].rect.area();
+    }
+
+    for (int i = 0; i < n; i++)
+    {
+        const Object& a = faceobjects[i];
+
+        int keep = 1;
+        for (int j = 0; j < (int)picked.size(); j++)
+        {
+            const Object& b = faceobjects[picked[j]];
+
+            // intersection over union
+            float inter_area = intersection_area(a, b);
+            float union_area = areas[i] + areas[picked[j]] - inter_area;
+            // float IoU = inter_area / union_area
+            if (inter_area / union_area > nms_threshold)
+                keep = 0;
+        }
+
+        if (keep)
+            picked.push_back(i);
+    }
+}
+
+static void generate_grids_and_stride(const int target_size, std::vector<int>& strides, std::vector<GridAndStride>& grid_strides)
+{
+    for (int i = 0; i < (int)strides.size(); i++)
+    {
+        int stride = strides[i];
+        int num_grid = target_size / stride;
+        for (int g1 = 0; g1 < num_grid; g1++)
+        {
+            for (int g0 = 0; g0 < num_grid; g0++)
+            {
+                GridAndStride gs;
+                gs.grid0 = g0;
+                gs.grid1 = g1;
+                gs.stride = stride;
+                grid_strides.push_back(gs);
+            }
+        }
+    }
+}
+
+static void generate_yolox_proposals(std::vector<GridAndStride> grid_strides, const ncnn::Mat& feat_blob, float prob_threshold, std::vector<Object>& objects)
+{
+    const int num_grid = feat_blob.h;
+    const int num_class = feat_blob.w - 5;
+    const int num_anchors = grid_strides.size();
+
+    const float* feat_ptr = feat_blob.channel(0);
+    for (int anchor_idx = 0; anchor_idx < num_anchors; anchor_idx++)
+    {
+        const int grid0 = grid_strides[anchor_idx].grid0;
+        const int grid1 = grid_strides[anchor_idx].grid1;
+        const int stride = grid_strides[anchor_idx].stride;
+
+        // yolox/models/yolo_head.py decode logic
+        //  outputs[..., :2] = (outputs[..., :2] + grids) * strides
+        //  outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides
+        float x_center = (feat_ptr[0] + grid0) * stride;
+        float y_center = (feat_ptr[1] + grid1) * stride;
+        float w = exp(feat_ptr[2]) * stride;
+        float h = exp(feat_ptr[3]) * stride;
+        float x0 = x_center - w * 0.5f;
+        float y0 = y_center - h * 0.5f;
+
+        float box_objectness = feat_ptr[4];
+        for (int class_idx = 0; class_idx < num_class; class_idx++)
+        {
+            float box_cls_score = feat_ptr[5 + class_idx];
+            float box_prob = box_objectness * box_cls_score;
+            if (box_prob > prob_threshold)
+            {
+                Object obj;
+                obj.rect.x = x0;
+                obj.rect.y = y0;
+                obj.rect.width = w;
+                obj.rect.height = h;
+                obj.label = class_idx;
+                obj.prob = box_prob;
+
+                objects.push_back(obj);
+            }
+
+        } // class loop
+        feat_ptr += feat_blob.w;
+
+    } // point anchor loop
+}
+
+static int detect_yolox(const cv::Mat& bgr, std::vector<Object>& objects)
+{
+    ncnn::Net yolox;
+
+    yolox.opt.use_vulkan_compute = true;
+    // yolox.opt.use_bf16_storage = true;
+
+    // Focus in yolov5
+    yolox.register_custom_layer("YoloV5Focus", YoloV5Focus_layer_creator);
+
+    // original pretrained model from https://github.com/Megvii-BaseDetection/YOLOX
+    // ncnn model param: https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_s_ncnn.tar.gz
+    yolox.load_param("yolox.param");
+    yolox.load_model("yolox.bin");
+
+    int img_w = bgr.cols;
+    int img_h = bgr.rows;
+
+    int w = img_w;
+    int h = img_h;
+    float scale = 1.f;
+    if (w > h)
+    {
+        scale = (float)YOLOX_TARGET_SIZE / w;
+        w = YOLOX_TARGET_SIZE;
+        h = h * scale;
+    }
+    else
+    {
+        scale = (float)YOLOX_TARGET_SIZE / h;
+        h = YOLOX_TARGET_SIZE;
+        w = w * scale;
+    }
+    ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR, img_w, img_h, w, h);
+
+    // pad to YOLOX_TARGET_SIZE rectangle
+    int wpad = YOLOX_TARGET_SIZE - w;
+    int hpad = YOLOX_TARGET_SIZE - h;
+    ncnn::Mat in_pad;
+    // different from yolov5, yolox only pad on bottom and right side,
+    // which means users don't need to extra padding info to decode boxes coordinate.
+    ncnn::copy_make_border(in, in_pad, 0, hpad, 0, wpad, ncnn::BORDER_CONSTANT, 114.f);
+
+    ncnn::Extractor ex = yolox.create_extractor();
+
+    ex.input("images", in_pad);
+
+    std::vector<Object> proposals;
+
+    {
+        ncnn::Mat out;
+        ex.extract("output", out);
+
+        static const int stride_arr[] = {8, 16, 32}; // might have stride=64 in YOLOX
+        std::vector<int> strides(stride_arr, stride_arr + sizeof(stride_arr) / sizeof(stride_arr[0]));
+        std::vector<GridAndStride> grid_strides;
+        generate_grids_and_stride(YOLOX_TARGET_SIZE, strides, grid_strides);
+        generate_yolox_proposals(grid_strides, out, YOLOX_CONF_THRESH, proposals);
+    }
+
+    // sort all proposals by score from highest to lowest
+    qsort_descent_inplace(proposals);
+
+    // apply nms with nms_threshold
+    std::vector<int> picked;
+    nms_sorted_bboxes(proposals, picked, YOLOX_NMS_THRESH);
+
+    int count = picked.size();
+
+    objects.resize(count);
+    for (int i = 0; i < count; i++)
+    {
+        objects[i] = proposals[picked[i]];
+
+        // adjust offset to original unpadded
+        float x0 = (objects[i].rect.x) / scale;
+        float y0 = (objects[i].rect.y) / scale;
+        float x1 = (objects[i].rect.x + objects[i].rect.width) / scale;
+        float y1 = (objects[i].rect.y + objects[i].rect.height) / scale;
+
+        // clip
+        x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f);
+        y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f);
+        x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f);
+        y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f);
+
+        objects[i].rect.x = x0;
+        objects[i].rect.y = y0;
+        objects[i].rect.width = x1 - x0;
+        objects[i].rect.height = y1 - y0;
+    }
+
+    return 0;
+}
+
+static void draw_objects(const cv::Mat& bgr, const std::vector<Object>& objects)
+{
+    static const char* class_names[] = {
+        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
+        "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
+        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
+        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
+        "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
+        "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
+        "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
+        "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
+        "hair drier", "toothbrush"
+    };
+
+    cv::Mat image = bgr.clone();
+
+    for (size_t i = 0; i < objects.size(); i++)
+    {
+        const Object& obj = objects[i];
+
+        fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob,
+                obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height);
+
+        cv::rectangle(image, obj.rect, cv::Scalar(255, 0, 0));
+
+        char text[256];
+        sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100);
+
+        int baseLine = 0;
+        cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
+
+        int x = obj.rect.x;
+        int y = obj.rect.y - label_size.height - baseLine;
+        if (y < 0)
+            y = 0;
+        if (x + label_size.width > image.cols)
+            x = image.cols - label_size.width;
+
+        cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),
+                      cv::Scalar(255, 255, 255), -1);
+
+        cv::putText(image, text, cv::Point(x, y + label_size.height),
+                    cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0));
+    }
+
+    cv::imshow("image", image);
+    cv::waitKey(0);
+}
+
+int main(int argc, char** argv)
+{
+    if (argc != 2)
+    {
+        fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]);
+        return -1;
+    }
+
+    const char* imagepath = argv[1];
+
+    cv::Mat m = cv::imread(imagepath, 1);
+    if (m.empty())
+    {
+        fprintf(stderr, "cv::imread %s failed\n", imagepath);
+        return -1;
+    }
+
+    std::vector<Object> objects;
+    detect_yolox(m, objects);
+
+    draw_objects(m, objects);
+
+    return 0;
+}
diff --git a/multimodal/YOLOX/demo/nebullvm/README.md b/multimodal/YOLOX/demo/nebullvm/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..24253505544c9ad7ec069a30e78e5a3d2d42dbbb
--- /dev/null
+++ b/multimodal/YOLOX/demo/nebullvm/README.md
@@ -0,0 +1,95 @@
+# **Accelerate YOLOX inference with nebullvm in Python**
+
+This document shows how to accelerate YOLOX inference time with nebullvm.
+
+[nebullvm](https://github.com/nebuly-ai/nebullvm) is an open-source library designed to accelerate AI inference of deep learning models in a few lines of code. nebullvm leverages state-of-the-art model optimization techniques such as deep learning compilers (TensorRT, Openvino, ONNX Runtime, TVM, TF Lite, DeepSparse, etc.), various quantization and compression strategies to achieve the maximum physically possible acceleration on the user's hardware.
+
+## Benchmarks
+Following are the results of the nebullvm optimization on YOLOX without loss of accuracy.
+For each model-hardware pairing, response time was evaluated as the average over 100 predictions. The test was run on Nvidia Tesla T4 (g4dn.xlarge) and Intel XEON Scalable (m6i.24xlarge and c6i.12xlarge) on AWS.
+
+| Model   | Hardware     | Unoptimized (ms)| Nebullvm optimized (ms) | Speedup |
+|---------|--------------|-----------------|-------------------------|---------|
+| YOLOX-s | g4dn.xlarge  |       13.6      |           9.0           |   1.5x  |
+| YOLOX-s | m6i.24xlarge |       32.7      |           8.8           |   3.7x  |
+| YOLOX-s | c6i.12xlarge |       34.4      |           12.4          |   2.8x  |
+| YOLOX-m | g4dn.xlarge  |       24.2      |           22.4          |   1.1x  |
+| YOLOX-m | m6i.24xlarge |       55.1      |           36.0          |   2.3x  |
+| YOLOX-m | c6i.12xlarge |       62.5      |           26.9          |   2.6x  |
+| YOLOX-l | g4dn.xlarge  |       84.4      |           80.5          |   1.5x  |
+| YOLOX-l | m6i.24xlarge |       88.0      |           33.7          |   2.6x  |
+| YOLOX-l | c6i.12xlarge |      102.8      |           54.2          |   1.9x  |
+| YOLOX-x | g4dn.xlarge  |       87.3      |           34.0          |   2.6x  |
+| YOLOX-x | m6i.24xlarge |      134.5      |           56.6          |   2.4x  |
+| YOLOX-x | c6i.12xlarge |      162.0      |           95.4          |   1.7x  |
+
+## Steps to accelerate YOLOX with nebullvm
+1. Download a YOLOX model from the original [readme](https://github.com/Megvii-BaseDetection/YOLOX)
+2. Optimize YOLOX with nebullvm
+3. Perform inference and compare the latency of the optimized model with that of the original model
+
+[Here](nebullvm_optimization.py) you can find a demo in python.
+
+
+First, let's install nebullvm. The simplest way is by using pip.
+```
+pip install nebullvm
+```
+Now, let's download one of YOLOX models and optimize it with nebullvm.
+
+```python
+# Import YOLOX model
+from yolox.exp import get_exp
+from yolox.data.data_augment import ValTransform
+
+exp = get_exp(None, 'yolox-s') # select model name
+model = exp.get_model()
+model.cuda()
+model.eval()
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+input_data =  [((torch.randn(1, 3, 640, 640).to(device), ), 0) for i in range(100)]
+
+# Run nebullvm optimization without performance loss
+optimized_model = optimize_model(model, input_data=input_data, optimization_time="constrained")
+```
+Find [here](nebullvm_optimize.py) the complete script in python with more details.
+
+In this example, we optimized YOLOX without any loss in accuracy. To further speed up the model by means of more aggressive optimization techniques, proceed as follows:
+- Set *optimization_time="unconstrained"*. With the unconstrained option, nebullvm will test time-consuming techniques such as pruning and quantization-aware training (QAT).
+- Set the *metric_drop_ths* parameter to be greater than zero (by default, *metric_drop_ths=0*). In this way, we will allow nebullvm to test optimization techniques that involve a tradeoff of some trade-off of a certain metric. For example, to test maximum acceleration with a minimum loss of accuracy of 3%, set *metric_drop_ths=0.03* and *metric="accuracy"*.
+For more information about nebullvm API, see [nebullvm documentation](https://github.com/nebuly-ai/nebullvm).
+
+
+Let's now compare the latency of the optimized model with that of the original model. 
+Note that before testing latency of the optimized model, it is necessary to perform some warmup runs, as some optimizers fine-tune certain internal parameters during the first few inferences after optimization.
+
+```python
+# Check perfomance
+warmup_iters = 30
+num_iters = 100
+
+# Unoptimized model perfomance
+with torch.no_grad():
+  for i in range(warmup_iters):
+    o = model(img)
+
+    start = time.time()
+    for i in range(num_iters):
+      o = model(img)
+stop = time.time()
+print(f"Average inference time of unoptimized YOLOX: {(stop - start)/num_iters*1000} ms")
+
+# Optimized model perfomance
+with torch.no_grad():
+  for i in range(warmup_iters):
+    res = model_opt(img)
+
+    start = time.time()
+    for i in range(num_iters):
+      res = model_opt(img)
+stop = time.time()
+print(f"Average inference time of YOLOX otpimized with nebullvm: {(stop - start)/num_iters*1000} ms")
+```
+Find [here](nebullvm_optimization.py) the complete script in python with more details.
diff --git a/multimodal/YOLOX/demo/nebullvm/nebullvm_optimization.py b/multimodal/YOLOX/demo/nebullvm/nebullvm_optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..b817baf62a23ab19a2ab34b26e021ffe4a266302
--- /dev/null
+++ b/multimodal/YOLOX/demo/nebullvm/nebullvm_optimization.py
@@ -0,0 +1,51 @@
+import torch
+import time
+from nebullvm.api.functions import optimize_model # Install DL compilers
+from yolox.exp import get_exp
+
+# Get YOLO model
+exp = get_exp(None, 'yolox-s') # select model name
+model = exp.get_model()
+model.cuda()
+model.eval()
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Create dummy data for the optimizer
+input_data =  [((torch.randn(1, 3, 640, 640).to(device), ), 0) for i in range(100)] 
+
+# ---------- Optimization ---------- 
+optimized_model = optimize_model(model, input_data=input_data, optimization_time="constrained")  # Optimization without performance loss
+
+
+# ---------- Benchmarks ---------- 
+# Select image to test the latency of the optimized model
+
+# Create dummy image
+img = torch.randn(1, 3, 640, 640).to(device)
+
+# Check perfomance
+warmup_iters = 30
+num_iters = 100
+
+# Unptimized model perfomance
+with torch.no_grad():
+  for i in range(warmup_iters):
+    o = model(img)
+
+    start = time.time()
+    for i in range(num_iters):
+      o = model(img)
+stop = time.time()
+print(f"Average inference time of unoptimized YOLOX: {(stop - start)/num_iters*1000} ms")
+
+# Optimized model perfomance
+with torch.no_grad():
+  for i in range(warmup_iters):
+    res = optimized_model(img)
+
+    start = time.time()
+    for i in range(num_iters):
+      res = optimized_model(img)
+stop = time.time()
+print(f"Average inference time of YOLOX otpimized with nebullvm: {(stop - start)/num_iters*1000} ms")
diff --git a/multimodal/YOLOX/docs/.gitignore b/multimodal/YOLOX/docs/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..9c5f57827018f6a0435036d9515314e34604b9fe
--- /dev/null
+++ b/multimodal/YOLOX/docs/.gitignore
@@ -0,0 +1 @@
+_build
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/Makefile b/multimodal/YOLOX/docs/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..ce61fb6a84ca97d25d833b64fa1b66d8f3e6ae7f
--- /dev/null
+++ b/multimodal/YOLOX/docs/Makefile
@@ -0,0 +1,19 @@
+# Minimal makefile for Sphinx documentation
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# You can set these variables from the command line.
+SPHINXOPTS    =
+SPHINXBUILD   = sphinx-build
+SOURCEDIR     = .
+BUILDDIR      = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/_static/css/custom.css b/multimodal/YOLOX/docs/_static/css/custom.css
new file mode 100644
index 0000000000000000000000000000000000000000..81f77f57d08d8be8c876906fb6455169bec1b39d
--- /dev/null
+++ b/multimodal/YOLOX/docs/_static/css/custom.css
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * some extra css to make markdown look similar between github/sphinx
+ */
+
+/*
+ * Below is for install.md:
+ */
+ .rst-content code {
+    white-space: pre;
+    border: 0px;
+  }
+  
+  .rst-content th {
+    border: 1px solid #e1e4e5;
+  }
+  
+  .rst-content th p {
+    /* otherwise will be default 24px for regular paragraph */
+    margin-bottom: 0px;
+  }
+  
+  .rst-content .line-block {
+    /* otherwise will be 24px */
+    margin-bottom: 0px;
+  }
+  
+  div.section > details {
+    padding-bottom: 1em;
+  }
+  
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/assignment_visualization.md b/multimodal/YOLOX/docs/assignment_visualization.md
new file mode 100644
index 0000000000000000000000000000000000000000..4bc7791f92ad58f7071d25bb668a18d144a4b6c4
--- /dev/null
+++ b/multimodal/YOLOX/docs/assignment_visualization.md
@@ -0,0 +1,29 @@
+# Visualize label assignment
+
+This tutorial explains how to visualize your label asssignment result when training with YOLOX.
+
+## 1. Visualization command
+
+We provide a visualization tool to help you visualize your label assignment result. You can find it in [`tools/visualize_assignment.py`](../tools/visualize_assign.py).
+
+Here is an example of command to visualize your label assignment result:
+
+```shell
+python3 tools/visualize_assign.py -f /path/to/your/exp.py yolox-s -d 1 -b 8 --max-batch 2
+```
+
+`max-batch` here means the maximum number of batches to visualize. The default value is 1, which the tool means only visualize the first batch.
+
+By the way, the mosaic augmentation is used in default dataloader, so you can also see the mosaic result here.
+
+After running the command, the logger will show you where the visualization result is saved, let's open it and into the step 2.
+
+## 2. Check the visualization result
+
+Here is an example of visualization result:
+<div align="center"><img src="../assets/assignment.png" width="640"></div>
+
+Those dots in one box is the matched anchor of gt box. **The color of dots is the same as the color of the box** to help you determine which object is assigned to the anchor. Note the box and dots are **instance level** visualization, which means the same class may have different colors.  
+**If the gt box doesn't match any anchor, the box will be marked as red and the red text "unmatched" will be drawn over the box**.
+
+Please feel free to open an issue if you have any questions.
diff --git a/multimodal/YOLOX/docs/cache.md b/multimodal/YOLOX/docs/cache.md
new file mode 100755
index 0000000000000000000000000000000000000000..66aded7cb7e83ef4f568912cd7d5c751a74006b9
--- /dev/null
+++ b/multimodal/YOLOX/docs/cache.md
@@ -0,0 +1,97 @@
+# Cache Custom Data
+
+The caching feature is specifically tailored for users with ample memory resources. However, we still offer the option to cache data to disk, but disk performance can vary and may not guarantee optimal user experience. Implementing custom dataset RAM caching is also more straightforward and user-friendly compared to disk caching. With a few simple modifications, users can expect to see a significant increase in training speed, with speeds nearly double that of non-cached datasets.
+
+This page explains how to cache your own custom data with YOLOX.
+
+## 0. Before you start
+
+**Step1** Clone this repo and follow the [README](../README.md) to install YOLOX.
+
+**Stpe2** Read the [Training on custom data](./train_custom_data.md) tutorial to understand how to prepare your custom data.
+
+## 1. Inheirit from `CacheDataset`
+
+
+**Step1** Create a custom dataset that inherits from the `CacheDataset` class. Note that whether inheriting from `Dataset` or `CacheDataset `, the `__init__()` method of your custom dataset should take the following keyword arguments: `input_dimension`, `cache`, and `cache_type`. Also, call `super().__init__()` and pass in `input_dimension`, `num_imgs`, `cache`, and `cache_type` as input, where `num_imgs` is the size of the dataset.
+
+**Step2** Implement the abstract function `read_img(self, index, use_cache=True)` of parent class and decorate it with `@cache_read_img`.  This function takes an `index` as input and returns an `image`, and the returned image will be used for caching. It is recommended to put all repetitive and fixed post-processing operations on the image in this function to reduce the post-processing time of the image during training.
+
+```python
+# CustomDataset.py
+from yolox.data.datasets import CacheDataset, cache_read_img
+
+class CustomDataset(CacheDataset):
+    def __init__(self, input_dimension, cache, cache_type, *args, **kwargs):
+        # Get the required keyword arguments of super().__init__()
+        super().__init__(
+            input_dimension=input_dimension,
+            num_imgs=num_imgs,
+            cache=cache,
+            cache_type=cache_type
+        )
+        # ...
+
+    @cache_read_img
+    def read_img(self, index, use_cache=True):
+        # get image ...
+        # (optional) repetitive and fixed post-processing operations for image
+        return image
+```
+
+## 2. Create your Exp file and return your custom dataset
+
+**Step1** Create a new class that inherits from the `Exp` class provided by the `yolox_base.py`. Override the `get_dataset()` and `get_eval_dataset()` method to return an instance of your custom dataset.
+
+**Step2** Implement your own `get_evaluator` method to return an instance of your custom evaluator.
+
+```python
+# CustomeExp.py
+from yolox.exp import Exp as MyExp
+
+class Exp(MyExp):
+    def get_dataset(self, cache, cache_type: str = "ram"):
+        return CustomDataset(
+            input_dimension=self.input_size,
+            cache=cache,
+            cache_type=cache_type
+        )
+
+    def get_eval_dataset(self):
+        return CustomDataset(
+            input_dimension=self.input_size,
+        )
+
+    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
+        return CustomEvaluator(
+            dataloader=self.get_eval_loader(batch_size, is_distributed, testdev=testdev, legacy=legacy),
+            img_size=self.test_size,
+            confthre=self.test_conf,
+            nmsthre=self.nmsthre,
+            num_classes=self.num_classes,
+            testdev=testdev,
+        )
+```
+
+**(Optional)** `get_data_loader` and `get_eval_loader` are now a default behavior in `yolox_base.py` and generally do not need to be changed. If you have to change `get_data_loader`, you need to add the following code at the beginning.
+
+```python
+# CustomeExp.py
+from yolox.exp import Exp as MyExp
+
+class Exp(MyExp):
+    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
+        if self.dataset is None:
+            with wait_for_the_master():
+                assert cache_img is None
+                self.dataset = self.get_dataset(cache=False, cache_type=cache_img)
+        # ...
+
+```
+
+## 3. Cache to Disk
+It's important to note that the `cache_type` can be `"ram"` or `"disk"`, depending on where you want to cache your dataset. If you choose `"disk"`, you need to pass in additional parameters to `super().__init__()` of `CustomDataset`: `data_dir`, `cache_dir_name`, `path_filename`.
+
+- `data_dir`: the root directory of the dataset, e.g. `/path/to/COCO`.
+- `cache_dir_name`: the name of the directory to cache to disk, for example `"custom_cache"`, then the files cached to disk will be saved under `/path/to/COCO/custom_cache`.
+- `path_filename`: a list of paths to the data relative to the `data_dir`, e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`, then `path_filename = ['train/1.jpg', ' train/2.jpg']`.
diff --git a/multimodal/YOLOX/docs/conf.py b/multimodal/YOLOX/docs/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d529682b248d7fb33668e0a4c56f5b178efa675
--- /dev/null
+++ b/multimodal/YOLOX/docs/conf.py
@@ -0,0 +1,384 @@
+# -*- coding: utf-8 -*-
+# Code are based on
+# https://github.com/facebookresearch/detectron2/blob/master/docs/conf.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+# flake8: noqa
+
+# Configuration file for the Sphinx documentation builder.
+#
+# This file does only contain a selection of the most common options. For a
+# full list see the documentation:
+# http://www.sphinx-doc.org/en/master/config
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+from unittest import mock
+from sphinx.domains import Domain
+from typing import Dict, List, Tuple
+
+# The theme to use for HTML and HTML Help pages.  See the documentation for
+# a list of builtin themes.
+#
+import sphinx_rtd_theme
+
+
+class GithubURLDomain(Domain):
+    """
+    Resolve certain links in markdown files to github source.
+    """
+
+    name = "githuburl"
+    ROOT = "https://github.com/Megvii-BaseDetection/YOLOX"
+    # LINKED_DOC = ["tutorials/install", "tutorials/getting_started"]
+    LINKED_DOC = ["tutorials/install",]
+
+    def resolve_any_xref(self, env, fromdocname, builder, target, node, contnode):
+        github_url = None
+        if not target.endswith("html") and target.startswith("../../"):
+            url = target.replace("../", "")
+            github_url = url
+        if fromdocname in self.LINKED_DOC:
+            # unresolved links in these docs are all github links
+            github_url = target
+
+        if github_url is not None:
+            if github_url.endswith("MODEL_ZOO") or github_url.endswith("README"):
+                # bug of recommonmark.
+                # https://github.com/readthedocs/recommonmark/blob/ddd56e7717e9745f11300059e4268e204138a6b1/recommonmark/parser.py#L152-L155
+                github_url += ".md"
+            print("Ref {} resolved to github:{}".format(target, github_url))
+            contnode["refuri"] = self.ROOT + github_url
+            return [("githuburl:any", contnode)]
+        else:
+            return []
+
+
+# to support markdown
+from recommonmark.parser import CommonMarkParser
+
+sys.path.insert(0, os.path.abspath("../"))
+os.environ["_DOC_BUILDING"] = "True"
+DEPLOY = os.environ.get("READTHEDOCS") == "True"
+
+
+# -- Project information -----------------------------------------------------
+
+# fmt: off
+try:
+    import torch  # noqa
+except ImportError:
+    for m in [
+        "torch", "torchvision", "torch.nn", "torch.nn.parallel", "torch.distributed", "torch.multiprocessing", "torch.autograd",
+        "torch.autograd.function", "torch.nn.modules", "torch.nn.modules.utils", "torch.utils", "torch.utils.data", "torch.onnx",
+        "torchvision", "torchvision.ops",
+    ]:
+        sys.modules[m] = mock.Mock(name=m)
+    sys.modules['torch'].__version__ = "1.7"  # fake version
+    HAS_TORCH = False
+else:
+    try:
+        torch.ops.yolox = mock.Mock(name="torch.ops.yolox")
+    except:
+        pass
+    HAS_TORCH = True
+
+for m in [
+    "cv2", "scipy", "portalocker", "yolox._C",
+    "pycocotools", "pycocotools.mask", "pycocotools.coco", "pycocotools.cocoeval",
+    "google", "google.protobuf", "google.protobuf.internal", "onnx",
+    "caffe2", "caffe2.proto", "caffe2.python", "caffe2.python.utils", "caffe2.python.onnx", "caffe2.python.onnx.backend",
+]:
+    sys.modules[m] = mock.Mock(name=m)
+# fmt: on
+sys.modules["cv2"].__version__ = "3.4"
+
+import yolox  # isort: skip
+
+# if HAS_TORCH:
+#     from detectron2.utils.env import fixup_module_metadata
+
+#     fixup_module_metadata("torch.nn", torch.nn.__dict__)
+#     fixup_module_metadata("torch.utils.data", torch.utils.data.__dict__)
+
+
+project = "YOLOX"
+copyright = "2021-2021, YOLOX contributors"
+author = "YOLOX contributors"
+
+# The short X.Y version
+version = yolox.__version__
+# The full version, including alpha/beta/rc tags
+release = version
+
+
+# -- General configuration ---------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#
+needs_sphinx = "3.0"
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+    "recommonmark",
+    "sphinx.ext.autodoc",
+    "sphinx.ext.napoleon",
+    "sphinx.ext.intersphinx",
+    "sphinx.ext.todo",
+    "sphinx.ext.coverage",
+    "sphinx.ext.mathjax",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.githubpages",
+    'sphinx_markdown_tables',
+]
+
+# -- Configurations for plugins ------------
+napoleon_google_docstring = True
+napoleon_include_init_with_doc = True
+napoleon_include_special_with_doc = True
+napoleon_numpy_docstring = False
+napoleon_use_rtype = False
+autodoc_inherit_docstrings = False
+autodoc_member_order = "bysource"
+
+if DEPLOY:
+    intersphinx_timeout = 10
+else:
+    # skip this when building locally
+    intersphinx_timeout = 0.5
+intersphinx_mapping = {
+    "python": ("https://docs.python.org/3.6", None),
+    "numpy": ("https://docs.scipy.org/doc/numpy/", None),
+    "torch": ("https://pytorch.org/docs/master/", None),
+}
+# -------------------------
+
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+source_suffix = [".rst", ".md"]
+
+# The master toctree document.
+master_doc = "index"
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "build", "README.md", "tutorials/README.md"]
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = "sphinx"
+
+
+# -- Options for HTML output -------------------------------------------------
+
+html_theme = "sphinx_rtd_theme"
+html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further.  For a list of options available for each theme, see the
+# documentation.
+#
+# html_theme_options = {}
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+html_css_files = ["css/custom.css"]
+
+# Custom sidebar templates, must be a dictionary that maps document names
+# to template names.
+#
+# The default sidebars (for documents that don't match any pattern) are
+# defined by theme itself.  Builtin themes are using these templates by
+# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
+# 'searchbox.html']``.
+#
+# html_sidebars = {}
+
+
+# -- Options for HTMLHelp output ---------------------------------------------
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = "yoloxdoc"
+
+
+# -- Options for LaTeX output ------------------------------------------------
+
+latex_elements = {
+    # The paper size ('letterpaper' or 'a4paper').
+    #
+    # 'papersize': 'letterpaper',
+    # The font size ('10pt', '11pt' or '12pt').
+    #
+    # 'pointsize': '10pt',
+    # Additional stuff for the LaTeX preamble.
+    #
+    # 'preamble': '',
+    # Latex figure (float) alignment
+    #
+    # 'figure_align': 'htbp',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title,
+#  author, documentclass [howto, manual, or own class]).
+latex_documents = [
+    (master_doc, "yolox.tex", "yolox Documentation", "yolox contributors", "manual")
+]
+
+
+# -- Options for manual page output ------------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [(master_doc, "YOLOX", "YOLOX Documentation", [author], 1)]
+
+
+# -- Options for Texinfo output ----------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+#  dir menu entry, description, category)
+texinfo_documents = [
+    (
+        master_doc,
+        "YOLOX",
+        "YOLOX Documentation",
+        author,
+        "YOLOX",
+        "One line description of project.",
+        "Miscellaneous",
+    )
+]
+
+
+# -- Options for todo extension ----------------------------------------------
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = True
+
+
+def autodoc_skip_member(app, what, name, obj, skip, options):
+    # we hide something deliberately
+    if getattr(obj, "__HIDE_SPHINX_DOC__", False):
+        return True
+
+    # Hide some that are deprecated or not intended to be used
+    HIDDEN = {
+        "ResNetBlockBase",
+        "GroupedBatchSampler",
+        "build_transform_gen",
+        "export_caffe2_model",
+        "export_onnx_model",
+        "apply_transform_gens",
+        "TransformGen",
+        "apply_augmentations",
+        "StandardAugInput",
+        "build_batch_data_loader",
+        "draw_panoptic_seg_predictions",
+        "WarmupCosineLR",
+        "WarmupMultiStepLR",
+    }
+    try:
+        if name in HIDDEN or (
+            hasattr(obj, "__doc__") and obj.__doc__.lower().strip().startswith("deprecated")
+        ):
+            print("Skipping deprecated object: {}".format(name))
+            return True
+    except:
+        pass
+    return skip
+
+
+# _PAPER_DATA = {
+#     "resnet": ("1512.03385", "Deep Residual Learning for Image Recognition"),
+#     "fpn": ("1612.03144", "Feature Pyramid Networks for Object Detection"),
+#     "mask r-cnn": ("1703.06870", "Mask R-CNN"),
+#     "faster r-cnn": (
+#         "1506.01497",
+#         "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks",
+#     ),
+#     "deformconv": ("1703.06211", "Deformable Convolutional Networks"),
+#     "deformconv2": ("1811.11168", "Deformable ConvNets v2: More Deformable, Better Results"),
+#     "panopticfpn": ("1901.02446", "Panoptic Feature Pyramid Networks"),
+#     "retinanet": ("1708.02002", "Focal Loss for Dense Object Detection"),
+#     "cascade r-cnn": ("1712.00726", "Cascade R-CNN: Delving into High Quality Object Detection"),
+#     "lvis": ("1908.03195", "LVIS: A Dataset for Large Vocabulary Instance Segmentation"),
+#     "rrpn": ("1703.01086", "Arbitrary-Oriented Scene Text Detection via Rotation Proposals"),
+#     "imagenet in 1h": ("1706.02677", "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"),
+#     "xception": ("1610.02357", "Xception: Deep Learning with Depthwise Separable Convolutions"),
+#     "mobilenet": (
+#         "1704.04861",
+#         "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications",
+#     ),
+#     "deeplabv3+": (
+#         "1802.02611",
+#         "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation",
+#     ),
+#     "dds": ("2003.13678", "Designing Network Design Spaces"),
+#     "scaling": ("2103.06877", "Fast and Accurate Model Scaling"),
+# }
+
+
+# def paper_ref_role(
+#     typ: str,
+#     rawtext: str,
+#     text: str,
+#     lineno: int,
+#     inliner,
+#     options: Dict = {},
+#     content: List[str] = [],
+# ):
+#     """
+#     Parse :paper:`xxx`. Similar to the "extlinks" sphinx extension.
+#     """
+#     from docutils import nodes, utils
+#     from sphinx.util.nodes import split_explicit_title
+
+#     text = utils.unescape(text)
+#     has_explicit_title, title, link = split_explicit_title(text)
+#     link = link.lower()
+#     if link not in _PAPER_DATA:
+#         inliner.reporter.warning("Cannot find paper " + link)
+#         paper_url, paper_title = "#", link
+#     else:
+#         paper_url, paper_title = _PAPER_DATA[link]
+#         if "/" not in paper_url:
+#             paper_url = "https://arxiv.org/abs/" + paper_url
+#     if not has_explicit_title:
+#         title = paper_title
+#     pnode = nodes.reference(title, title, internal=False, refuri=paper_url)
+#     return [pnode], []
+
+
+def setup(app):
+    from recommonmark.transform import AutoStructify
+
+    app.add_domain(GithubURLDomain)
+    app.connect("autodoc-skip-member", autodoc_skip_member)
+    # app.add_role("paper", paper_ref_role)
+    app.add_config_value(
+        "recommonmark_config",
+        {"enable_math": True, "enable_inline_math": True, "enable_eval_rst": True},
+        True,
+    )
+    app.add_transform(AutoStructify)
diff --git a/multimodal/YOLOX/docs/demo/megengine_cpp_readme.md b/multimodal/YOLOX/docs/demo/megengine_cpp_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..dbadb36bcdc06a7d62e99d7f2f0c59b40231e1b7
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/megengine_cpp_readme.md
@@ -0,0 +1 @@
+../../demo/MegEngine/cpp/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/demo/megengine_py_readme.md b/multimodal/YOLOX/docs/demo/megengine_py_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..3029e4a207535d150eeda5ad392306ffcb119593
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/megengine_py_readme.md
@@ -0,0 +1 @@
+../../demo/MegEngine/python/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/demo/ncnn_android_readme.md b/multimodal/YOLOX/docs/demo/ncnn_android_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..b623071454b4e1b10fa311da5941aa2ab4a406a7
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/ncnn_android_readme.md
@@ -0,0 +1 @@
+../../demo/ncnn/android/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/demo/ncnn_cpp_readme.md b/multimodal/YOLOX/docs/demo/ncnn_cpp_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..c00c01b06cd411581d3f269fd54a47e9d702b279
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/ncnn_cpp_readme.md
@@ -0,0 +1 @@
+../../demo/ncnn/cpp/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/demo/onnx_readme.md b/multimodal/YOLOX/docs/demo/onnx_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..bd85ab19678b58cf2a33ac1bdd3cecb449f951a2
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/onnx_readme.md
@@ -0,0 +1 @@
+../../demo/ONNXRuntime/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/demo/openvino_cpp_readme.md b/multimodal/YOLOX/docs/demo/openvino_cpp_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..3f455940a26a0cc4ce6b12fee4bb97725055458a
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/openvino_cpp_readme.md
@@ -0,0 +1 @@
+../../demo/OpenVINO/cpp/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/demo/openvino_py_readme.md b/multimodal/YOLOX/docs/demo/openvino_py_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..8adb770a576450bcc507861f98a36dd43bf00019
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/openvino_py_readme.md
@@ -0,0 +1 @@
+../../demo/OpenVINO/python/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/demo/trt_cpp_readme.md b/multimodal/YOLOX/docs/demo/trt_cpp_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..6efafeda5d98609265a5954115db34916a7aef4e
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/trt_cpp_readme.md
@@ -0,0 +1 @@
+../../demo/TensorRT/cpp/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/demo/trt_py_readme.md b/multimodal/YOLOX/docs/demo/trt_py_readme.md
new file mode 120000
index 0000000000000000000000000000000000000000..44df8914da8844e02d99a58744bbc21a42261d2a
--- /dev/null
+++ b/multimodal/YOLOX/docs/demo/trt_py_readme.md
@@ -0,0 +1 @@
+../../demo/TensorRT/python/README.md
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/freeze_module.md b/multimodal/YOLOX/docs/freeze_module.md
new file mode 100644
index 0000000000000000000000000000000000000000..421d95cd96d0a876f17ad57af899b2e06f0addbd
--- /dev/null
+++ b/multimodal/YOLOX/docs/freeze_module.md
@@ -0,0 +1,37 @@
+# Freeze module
+
+This page guide users to freeze module in YOLOX.  
+Exp controls everything in YOLOX, so let's start from creating an Exp object.
+
+## 1. Create your own expermiment object
+
+We take an example of YOLOX-S model on COCO dataset to give a more clear guide.
+
+Import the config you want (or write your own Exp object inherit from `yolox.exp.BaseExp`).
+```python
+from yolox.exp.default.yolox_s import Exp as MyExp
+```
+
+## 2. Override `get_model` method
+
+Here is a simple code to freeze backbone (FPN not included) of module.
+```python
+class Exp(MyExp):
+
+    def get_model(self):
+        from yolox.utils import freeze_module
+        model = super().get_model()
+        freeze_module(model.backbone.backbone)
+        return model
+```
+if you only want to freeze FPN, `freeze_module(model.backbone)` might help.
+
+## 3. Train
+Suppose that the path of your Exp  is `/path/to/my_exp.py`, use the following command to train your model.
+```bash
+python3 -m yolox.tools.train -f /path/to/my_exp.py
+```
+For more details of training, run the following command.
+```bash
+python3 -m yolox.tools.train --help
+```
diff --git a/multimodal/YOLOX/docs/index.rst b/multimodal/YOLOX/docs/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..76e4d08d75b623909c6fed9ec58ffba6f08ab537
--- /dev/null
+++ b/multimodal/YOLOX/docs/index.rst
@@ -0,0 +1,32 @@
+
+Welcome to YOLOX's documentation!
+======================================
+
+.. image:: ../assets/logo.png
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Quick Run
+   
+   quick_run
+   model_zoo
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
+
+   train_custom_data
+
+.. toctree::
+   :maxdepth: 2
+   :caption: Demployment
+
+   demo/trt_py_readme
+   demo/trt_cpp_readme
+   demo/megengine_cpp_readme
+   demo/megengine_py_readme
+   demo/ncnn_android_readme
+   demo/ncnn_cpp_readme
+   demo/onnx_readme
+   demo/openvino_py_readme
+   demo/openvino_cpp_readme
\ No newline at end of file
diff --git a/multimodal/YOLOX/docs/manipulate_training_image_size.md b/multimodal/YOLOX/docs/manipulate_training_image_size.md
new file mode 100644
index 0000000000000000000000000000000000000000..7a4e8560b672b931f3f64a2f502ea0e863ae2338
--- /dev/null
+++ b/multimodal/YOLOX/docs/manipulate_training_image_size.md
@@ -0,0 +1,59 @@
+# Manipulating Your Training Image Size
+
+This tutorial explains how to control your image size when training on your own data.
+
+## 1. Introduction
+
+There are 3 hyperparamters control the training size:
+
+- self.input_size = (640, 640) &emsp; #(height, width)
+- self.multiscale_range = 5
+- self.random_size = (14, 26)
+
+There is 1 hyperparameter constrols the testing size:
+
+- self.test_size = (640, 640)
+
+The self.input_size is suggested to set to the same value as self.test_size. By default, it is set to (640, 640) for most models and (416, 416) for yolox-tiny and yolox-nano.
+
+## 2. Multi Scale Training
+
+When training on your custom dataset, you can use multiscale training in 2 ways:
+
+1. **【Default】Only specifying the self.input_size and leaving others unchanged.**
+
+   If so, the actual multiscale sizes range from:
+
+   [self.input_size[0] - self.multiscale_range\*32,  self.input_size[0] + self.multiscale_range\*32]
+
+   For example, if you only set:
+
+   ```python
+   self.input_size = (640, 640)
+   ```
+
+   the actual multiscale range is [640 - 5*32, 640 + 5\*32], i.e., [480, 800].
+
+   You can modify self.multiscale_range to change the multiscale range.
+
+2. **Simultaneously specifying the self.input_size and self.random_size**
+
+   ```python
+   self.input_size = (416, 416)
+   self.random_size = (10, 20)
+   ```
+
+   In this case, the actual multiscale range is [self.random_size[0]\*32, self.random_size[1]\*32], i.e., [320, 640]
+
+   **Note: You must specify the self.input_size because it is used for initializing resize aug in dataset.**
+
+## 3. Single Scale Training
+
+If you want to train in a single scale. You need to specify the self.input_size and self.multiscale_range=0:
+
+```python
+self.input_size = (416, 416)
+self.multiscale_range = 0
+```
+
+**DO NOT** set the self.random_size.
diff --git a/multimodal/YOLOX/docs/model_zoo.md b/multimodal/YOLOX/docs/model_zoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..9bde131b943c07e1d5ee17b5740327d7e20eda41
--- /dev/null
+++ b/multimodal/YOLOX/docs/model_zoo.md
@@ -0,0 +1,42 @@
+# Model Zoo
+
+## Standard Models.
+
+|Model |size |mAP<sup>val<br>0.5:0.95 |mAP<sup>test<br>0.5:0.95 | Speed V100<br>(ms) | Params<br>(M) |FLOPs<br>(G)| weights |
+| ------        |:---: | :---:    | :---:       |:---:     |:---:  | :---: | :----: |
+|[YOLOX-s](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_s.py)    |640  |40.5 |40.5      |9.8      |9.0 | 26.8 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_s.pth) |
+|[YOLOX-m](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_m.py)    |640  |46.9 |47.2      |12.3     |25.3 |73.8| [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_m.pth) |
+|[YOLOX-l](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_l.py)    |640  |49.7 |50.1      |14.5     |54.2| 155.6 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_l.pth) |
+|[YOLOX-x](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_x.py)   |640   |51.1 |**51.5**  | 17.3    |99.1 |281.9 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_x.pth) |
+|[YOLOX-Darknet53](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolov3.py)   |640  | 47.7 | 48.0 | 11.1 |63.7 | 185.3 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_darknet.pth) 
+
+<details>
+<summary>Legacy models</summary>
+
+|Model |size |mAP<sup>test<br>0.5:0.95 | Speed V100<br>(ms) | Params<br>(M) |FLOPs<br>(G)| weights |
+| ------        |:---: | :---:       |:---:     |:---:  | :---: | :----: |
+|[YOLOX-s](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_s.py)    |640  |39.6      |9.8     |9.0 | 26.8 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EW62gmO2vnNNs5npxjzunVwB9p307qqygaCkXdTO88BLUg?e=NMTQYw)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_s.pth) |
+|[YOLOX-m](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_m.py)    |640  |46.4      |12.3     |25.3 |73.8| [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/ERMTP7VFqrVBrXKMU7Vl4TcBQs0SUeCT7kvc-JdIbej4tQ?e=1MDo9y)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_m.pth) |
+|[YOLOX-l](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_l.py)    |640  |50.0  |14.5 |54.2| 155.6 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EWA8w_IEOzBKvuueBqfaZh0BeoG5sVzR-XYbOJO4YlOkRw?e=wHWOBE)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_l.pth) |
+|[YOLOX-x](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_x.py)   |640  |**51.2**      | 17.3 |99.1 |281.9 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EdgVPHBziOVBtGAXHfeHI5kBza0q9yyueMGdT0wXZfI1rQ?e=tABO5u)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_x.pth) |
+|[YOLOX-Darknet53](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolov3.py)   |640  | 47.4      | 11.1 |63.7 | 185.3 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EZ-MV1r_fMFPkPrNjvbJEMoBLOLAnXH-XKEB77w8LhXL6Q?e=mf6wOc)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_darknet53.pth) |
+
+</details>
+
+## Light Models.
+
+|Model |size |mAP<sup>val<br>0.5:0.95 | Params<br>(M) |FLOPs<br>(G)| weights |
+| ------        |:---:  |  :---:       |:---:     |:---:  | :---: |
+|[YOLOX-Nano](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_nano.py) |416  |25.8  | 0.91 |1.08 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_nano.pth) |
+|[YOLOX-Tiny](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_tiny.py) |416  |32.8 | 5.06 |6.45 | [github](https://github.com/Megvii-BaseDetection/YOLOX/releases/download/0.1.1rc0/yolox_tiny.pth) |
+
+
+<details>
+<summary>Legacy models</summary>
+
+|Model |size |mAP<sup>val<br>0.5:0.95 | Params<br>(M) |FLOPs<br>(G)| weights |
+| ------        |:---:  |  :---:       |:---:     |:---:  | :---: |
+|[YOLOX-Nano](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_nano.py) |416  |25.3  | 0.91 |1.08 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EdcREey-krhLtdtSnxolxiUBjWMy6EFdiaO9bdOwZ5ygCQ?e=yQpdds)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_nano.pth) |
+|[YOLOX-Tiny](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/default/yolox_tiny.py) |416  |32.8 | 5.06 |6.45 | [onedrive](https://megvii-my.sharepoint.cn/:u:/g/personal/gezheng_megvii_com/EbZuinX5X1dJmNy8nqSRegABWspKw3QpXxuO82YSoFN1oQ?e=Q7V7XE)/[github](https://github.com/Megvii-BaseDetection/storage/releases/download/0.0.1/yolox_tiny_32dot8.pth) |
+
+</details>
diff --git a/multimodal/YOLOX/docs/quick_run.md b/multimodal/YOLOX/docs/quick_run.md
new file mode 100644
index 0000000000000000000000000000000000000000..f00bb995ba7859cfd6b2a8e540e0d8980d4adc11
--- /dev/null
+++ b/multimodal/YOLOX/docs/quick_run.md
@@ -0,0 +1,127 @@
+
+# Get Started
+
+## 1.Installation
+
+Step1. Install YOLOX.
+```shell
+git clone git@github.com:Megvii-BaseDetection/YOLOX.git
+cd YOLOX
+pip3 install -U pip && pip3 install -r requirements.txt
+pip3 install -v -e .  # or  python3 setup.py develop
+```
+Step2. Install [pycocotools](https://github.com/cocodataset/cocoapi).
+
+```shell
+pip3 install cython; pip3 install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
+```
+
+## 2.Demo
+
+Step1. Download a pretrained model from the benchmark table.
+
+Step2. Use either -n or -f to specify your detector's config. For example:
+
+```shell
+python tools/demo.py image -n yolox-s -c /path/to/your/yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device [cpu/gpu]
+```
+or
+```shell
+python tools/demo.py image -f exps/default/yolox_s.py -c /path/to/your/yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device [cpu/gpu]
+```
+Demo for video:
+```shell
+python tools/demo.py video -n yolox-s -c /path/to/your/yolox_s.pth --path /path/to/your/video --conf 0.25 --nms 0.45 --tsize 640 --save_result --device [cpu/gpu]
+```
+
+
+## 3.Reproduce our results on COCO
+
+Step1. Prepare COCO dataset
+```shell
+cd <YOLOX_HOME>
+ln -s /path/to/your/COCO ./datasets/COCO
+```
+
+Step2. Reproduce our results on COCO by specifying -n:
+
+```shell
+python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+* -d: number of gpu devices
+* -b: total batch size, the recommended number for -b is num-gpu * 8
+* --fp16: mixed precision training
+* --cache: caching imgs into RAM to accelarate training, which need large system RAM.
+
+**Weights & Biases for Logging**
+
+To use W&B for logging, install wandb in your environment and log in to your W&B account using
+
+```shell
+pip install wandb
+wandb login
+```
+
+Log in to your W&B account
+
+To start logging metrics to W&B during training add the flag `--logger` to the previous command and use the prefix "wandb-" to specify arguments for initializing the wandb run.
+
+```shell
+python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+
+More WandbLogger arguments include
+
+```shell
+python tools/train.py .... --logger wandb wandb-project <project-name> \
+                wandb-name <run-name> \
+                wandb-id <run-id> \
+                wandb-save_dir <save-dir> \
+                wandb-num_eval_images <num-images> \
+                wandb-log_checkpoints <bool>
+```
+
+More information available [here](https://docs.wandb.ai/guides/integrations/other/yolox).
+
+**Multi Machine Training**
+
+We also support multi-nodes training. Just add the following args:
+* --num\_machines: num of your total training nodes
+* --machine\_rank: specify the rank of each node
+
+When using -f, the above commands are equivalent to:
+
+```shell
+python tools/train.py -f exps/default/yolox-s.py -d 8 -b 64 --fp16 -o [--cache]
+                         exps/default/yolox-m.py
+                         exps/default/yolox-l.py
+                         exps/default/yolox-x.py
+```
+
+## 4.Evaluation
+
+We support batch testing for fast evaluation:
+
+```shell
+python tools/eval.py -n  yolox-s -c yolox_s.pth -b 64 -d 8 --conf 0.001 [--fp16] [--fuse]
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+* --fuse: fuse conv and bn
+* -d: number of GPUs used for evaluation. DEFAULT: All GPUs available will be used.
+* -b: total batch size across on all GPUs
+
+To reproduce speed test, we use the following command:
+```shell
+python tools/eval.py -n  yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --fp16 --fuse
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
diff --git a/multimodal/YOLOX/docs/requirements-doc.txt b/multimodal/YOLOX/docs/requirements-doc.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3f721536bb13e4566bc68613deefdaa436b51b44
--- /dev/null
+++ b/multimodal/YOLOX/docs/requirements-doc.txt
@@ -0,0 +1,8 @@
+docutils==0.16
+# https://github.com/sphinx-doc/sphinx/commit/7acd3ada3f38076af7b2b5c9f3b60bb9c2587a3d
+sphinx==3.2.0
+recommonmark==0.6.0
+sphinx_rtd_theme
+omegaconf>=2.1.0.dev24
+hydra-core>=1.1.0.dev5
+sphinx-markdown-tables==0.0.15
diff --git a/multimodal/YOLOX/docs/train_custom_data.md b/multimodal/YOLOX/docs/train_custom_data.md
new file mode 100644
index 0000000000000000000000000000000000000000..ee97cc94ae5ffd606053c1adcd057c86283c8185
--- /dev/null
+++ b/multimodal/YOLOX/docs/train_custom_data.md
@@ -0,0 +1,131 @@
+# Train Custom Data
+
+This page explains how to train your own custom data with YOLOX.
+
+We take an example of fine-tuning YOLOX-S model on VOC dataset to give a more clear guide.
+
+## 0. Before you start
+Clone this repo and follow the [README](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/README.md) to install YOLOX.
+
+## 1. Create your own dataset
+**Step 1** Prepare your own dataset with images and labels first. For labeling images, you can use tools like [Labelme](https://github.com/wkentaro/labelme) or [CVAT](https://github.com/openvinotoolkit/cvat).
+
+**Step 2** Then, you should write the corresponding Dataset Class which can load images and labels through `__getitem__` method. We currently support COCO format and VOC format.
+
+You can also write the Dataset by your own. Let's take the [VOC](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/datasets/voc.py#L151) Dataset file for example:
+```python
+    @Dataset.resize_getitem
+    def __getitem__(self, index):
+        img, target, img_info, img_id = self.pull_item(index)
+
+        if self.preproc is not None:
+            img, target = self.preproc(img, target, self.input_dim)
+
+        return img, target, img_info, img_id
+```
+
+One more thing worth noting is that you should also implement [pull_item](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/datasets/voc.py#L129) and [load_anno](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/data/datasets/voc.py#L121) method for the `Mosiac` and `MixUp` augmentations.
+
+**Step 3** Prepare the evaluator. We currently have [COCO evaluator](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/evaluators/coco_evaluator.py) and [VOC evaluator](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/evaluators/voc_evaluator.py).
+If you have your own format data or evaluation metric, you can write your own evaluator.
+
+**Step 4** Put your dataset under `$YOLOX_DIR/datasets`, for VOC:
+
+```shell
+ln -s /path/to/your/VOCdevkit ./datasets/VOCdevkit
+```
+* The path "VOCdevkit" will be used in your exp file described in next section. Specifically, in `get_data_loader` and `get_eval_loader` function.
+
+✧✧✧ You can download the mini-coco128 dataset by the [link](https://drive.google.com/file/d/16N3u36ycNd70m23IM7vMuRQXejAJY9Fs/view?usp=sharing), and then unzip it to the `datasets` directory. The dataset has been converted from YOLO format to COCO format, and can be used directly as a dataset for testing whether the train environment can be runned successfully.
+
+## 2. Create your Exp file to control everything
+We put everything involved in a model to one single Exp file, including model setting, training setting, and testing setting.
+
+**A complete Exp file is at [yolox_base.py](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/exp/yolox_base.py).** It may be too long to write for every exp, but you can inherit the base Exp file and only overwrite the changed part.
+
+Let's take the [VOC Exp file](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/example/yolox_voc/yolox_voc_s.py) as an example.
+
+We select `YOLOX-S` model here, so we should change the network depth and width. VOC has only 20 classes, so we should also change the `num_classes`.
+
+These configs are changed in the `init()` method:
+```python
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.num_classes = 20
+        self.depth = 0.33
+        self.width = 0.50
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+```
+
+Besides, you should also overwrite the `dataset` and `evaluator`, prepared before training the model on your own data.
+
+Please see [get_data_loader](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/example/yolox_voc/yolox_voc_s.py#L20), [get_eval_loader](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/example/yolox_voc/yolox_voc_s.py#L82), and [get_evaluator](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/exps/example/yolox_voc/yolox_voc_s.py#L113) for more details.
+
+✧✧✧ You can also see the `exps/example/custom` directory for more details.
+
+## 3. Train
+Except special cases, we always recommend to use our [COCO pretrained weights](https://github.com/Megvii-BaseDetection/YOLOX/blob/main/README.md) for initializing the model.
+
+Once you get the Exp file and the COCO pretrained weights we provided, you can train your own model by the following below command:
+```bash
+python tools/train.py -f /path/to/your/Exp/file -d 8 -b 64 --fp16 -o -c /path/to/the/pretrained/weights [--cache]
+```
+* --cache: we now support RAM caching to speed up training! Make sure you have enough system RAM when adopting it. 
+
+or take the `YOLOX-S` VOC training for example:
+```bash
+python tools/train.py -f exps/example/yolox_voc/yolox_voc_s.py -d 8 -b 64 --fp16 -o -c /path/to/yolox_s.pth [--cache]
+```
+
+✧✧✧ For example:
+- If you download the [mini-coco128](https://drive.google.com/file/d/16N3u36ycNd70m23IM7vMuRQXejAJY9Fs/view?usp=sharing) and unzip it to the `datasets`, you can direct run the following training code.
+    ```bash
+    python tools/train.py -f exps/example/custom/yolox_s.py -d 8 -b 64 --fp16 -o -c /path/to/yolox_s.pth
+    ```
+
+(Don't worry for the different shape of detection head between the pretrained weights and your own model, we will handle it)
+
+## 4. Tips for Best Training Results
+
+As **YOLOX** is an anchor-free detector with only several hyper-parameters, most of the time good results can be obtained with no changes to the models or training settings.
+We thus always recommend you first train with all default training settings.
+
+If at first you don't get good results, there are steps you could consider to improve the model.
+
+**Model Selection** We provide `YOLOX-Nano`, `YOLOX-Tiny`, and `YOLOX-S` for mobile deployments, while `YOLOX-M`/`L`/`X` for cloud or high performance GPU deployments.
+
+If your deployment meets any compatibility issues. we recommend `YOLOX-DarkNet53`.
+
+**Training Configs** If your training overfits early, then you can reduce max\_epochs or decrease the base\_lr and min\_lr\_ratio in your Exp file:
+
+```python
+# --------------  training config --------------------- #
+    self.warmup_epochs = 5
+    self.max_epoch = 300
+    self.warmup_lr = 0
+    self.basic_lr_per_img = 0.01 / 64.0
+    self.scheduler = "yoloxwarmcos"
+    self.no_aug_epochs = 15
+    self.min_lr_ratio = 0.05
+    self.ema = True
+
+    self.weight_decay = 5e-4
+    self.momentum = 0.9
+```
+
+**Aug Configs** You may also change the degree of the augmentations.
+
+Generally, for small models, you should weak the aug, while for large models or small size of dataset, you may enchance the aug in your Exp file:
+```python
+# --------------- transform config ----------------- #
+    self.degrees = 10.0
+    self.translate = 0.1
+    self.scale = (0.1, 2)
+    self.mosaic_scale = (0.8, 1.6)
+    self.shear = 2.0
+    self.perspective = 0.0
+    self.enable_mixup = True
+```
+
+**Design your own detector** You may refer to our [Arxiv](https://arxiv.org/abs/2107.08430) paper for details and suggestions for designing your own detector.
diff --git a/multimodal/YOLOX/docs/updates_note.md b/multimodal/YOLOX/docs/updates_note.md
new file mode 100644
index 0000000000000000000000000000000000000000..f675f43fcc36130d3294ab2c95210a56fdfb5c8e
--- /dev/null
+++ b/multimodal/YOLOX/docs/updates_note.md
@@ -0,0 +1,55 @@
+
+# Updates notes
+
+## 【2021/08/19】
+
+* Support image caching for faster training, which requires large system RAM. 
+* Remove the dependence of apex and support torch amp training. 
+* Optimize the preprocessing for faster training 
+* Replace the older distort augmentation with new HSV aug for faster training and better performance. 
+
+### 2X Faster training
+
+We optimize the data preprocess and support image caching with `--cache` flag:
+
+```shell
+python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+* -d: number of gpu devices
+* -b: total batch size, the recommended number for -b is num-gpu * 8
+* --fp16: mixed precision training
+* --cache: caching imgs into RAM to accelarate training, which need large system RAM.
+
+### Higher performance
+
+New models achieve **~1%** higher performance! See [Model_Zoo](model_zoo.md) for more details.
+
+### Support torch amp
+
+We now support torch.cuda.amp training and Apex is not used anymore.
+
+### Breaking changes
+
+We remove the normalization operation like -mean/std. This will make the old weights **incompatible**.
+
+If you still want to use old weights, you can add `--legacy' in demo and eval:
+
+```shell
+python tools/demo.py image -n yolox-s -c /path/to/your/yolox_s.pth --path assets/dog.jpg --conf 0.25 --nms 0.45 --tsize 640 --save_result --device [cpu/gpu] [--legacy]
+```
+
+and 
+
+```shell
+python tools/eval.py -n  yolox-s -c yolox_s.pth -b 64 -d 8 --conf 0.001 [--fp16] [--fuse] [--legacy]
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+
+But for deployment demo, we don't support the old weights anymore. Users could checkout to YOLOX version 0.1.0 to use legacy weights for deployment
+
+
diff --git a/multimodal/YOLOX/exps/default/__init__.py b/multimodal/YOLOX/exps/default/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce9fae0677b11bdd96e516f4b0b8a3782daed1ec
--- /dev/null
+++ b/multimodal/YOLOX/exps/default/__init__.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
diff --git a/multimodal/YOLOX/exps/default/yolov3.py b/multimodal/YOLOX/exps/default/yolov3.py
new file mode 100644
index 0000000000000000000000000000000000000000..c747f8ae9f42549a1dbd7f03d8ee80e235d6467a
--- /dev/null
+++ b/multimodal/YOLOX/exps/default/yolov3.py
@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+import torch.nn as nn
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.0
+        self.width = 1.0
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+    def get_model(self, sublinear=False):
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+        if "model" not in self.__dict__:
+            from yolox.models import YOLOX, YOLOFPN, YOLOXHead
+            backbone = YOLOFPN()
+            head = YOLOXHead(self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu")
+            self.model = YOLOX(backbone, head)
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+
+        return self.model
diff --git a/multimodal/YOLOX/exps/default/yolox_l.py b/multimodal/YOLOX/exps/default/yolox_l.py
new file mode 100644
index 0000000000000000000000000000000000000000..50833ca38c51fe9ac5e327d7c1c0561fb62249aa
--- /dev/null
+++ b/multimodal/YOLOX/exps/default/yolox_l.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.0
+        self.width = 1.0
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/multimodal/YOLOX/exps/default/yolox_m.py b/multimodal/YOLOX/exps/default/yolox_m.py
new file mode 100644
index 0000000000000000000000000000000000000000..9666a31177b9cc1c94978f9867aaceac8ddebce2
--- /dev/null
+++ b/multimodal/YOLOX/exps/default/yolox_m.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.67
+        self.width = 0.75
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/multimodal/YOLOX/exps/default/yolox_nano.py b/multimodal/YOLOX/exps/default/yolox_nano.py
new file mode 100644
index 0000000000000000000000000000000000000000..8955dd2a7748c900cab7dca11adf877cd2cf5abd
--- /dev/null
+++ b/multimodal/YOLOX/exps/default/yolox_nano.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+import torch.nn as nn
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.25
+        self.input_size = (416, 416)
+        self.random_size = (10, 20)
+        self.mosaic_scale = (0.5, 1.5)
+        self.test_size = (416, 416)
+        self.mosaic_prob = 0.5
+        self.enable_mixup = False
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+    def get_model(self, sublinear=False):
+
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+        if "model" not in self.__dict__:
+            from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+            in_channels = [256, 512, 1024]
+            # NANO model use depthwise = True, which is main difference.
+            backbone = YOLOPAFPN(
+                self.depth, self.width, in_channels=in_channels,
+                act=self.act, depthwise=True,
+            )
+            head = YOLOXHead(
+                self.num_classes, self.width, in_channels=in_channels,
+                act=self.act, depthwise=True
+            )
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
diff --git a/multimodal/YOLOX/exps/default/yolox_s.py b/multimodal/YOLOX/exps/default/yolox_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..abb6a8bbbe4fd1c6aff71596621aaeec2a6a15d8
--- /dev/null
+++ b/multimodal/YOLOX/exps/default/yolox_s.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.50
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/multimodal/YOLOX/exps/default/yolox_tiny.py b/multimodal/YOLOX/exps/default/yolox_tiny.py
new file mode 100644
index 0000000000000000000000000000000000000000..5220de2f2e6760d5c9a966d5dd397aad721fc60a
--- /dev/null
+++ b/multimodal/YOLOX/exps/default/yolox_tiny.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.375
+        self.input_size = (416, 416)
+        self.mosaic_scale = (0.5, 1.5)
+        self.random_size = (10, 20)
+        self.test_size = (416, 416)
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        self.enable_mixup = False
diff --git a/multimodal/YOLOX/exps/default/yolox_x.py b/multimodal/YOLOX/exps/default/yolox_x.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac498a1fb91f597e9362c2b73a9a002cf31445fc
--- /dev/null
+++ b/multimodal/YOLOX/exps/default/yolox_x.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 1.33
+        self.width = 1.25
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/multimodal/YOLOX/exps/example/custom/nano.py b/multimodal/YOLOX/exps/example/custom/nano.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb10626dbd37ea0744eba8d8340302aa4ffccffb
--- /dev/null
+++ b/multimodal/YOLOX/exps/example/custom/nano.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+
+import torch.nn as nn
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.25
+        self.input_size = (416, 416)
+        self.mosaic_scale = (0.5, 1.5)
+        self.random_size = (10, 20)
+        self.test_size = (416, 416)
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+        self.enable_mixup = False
+
+        # Define yourself dataset path
+        self.data_dir = "datasets/coco128"
+        self.train_ann = "instances_train2017.json"
+        self.val_ann = "instances_val2017.json"
+
+        self.num_classes = 71
+
+    def get_model(self, sublinear=False):
+
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+        if "model" not in self.__dict__:
+            from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+            in_channels = [256, 512, 1024]
+            # NANO model use depthwise = True, which is main difference.
+            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, depthwise=True)
+            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, depthwise=True)
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        return self.model
diff --git a/multimodal/YOLOX/exps/example/custom/yolox_s.py b/multimodal/YOLOX/exps/example/custom/yolox_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f0b0a5f76b63a993c24e3f33c69fd960144a42c
--- /dev/null
+++ b/multimodal/YOLOX/exps/example/custom/yolox_s.py
@@ -0,0 +1,25 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import os
+
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.depth = 0.33
+        self.width = 0.50
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+        # Define yourself dataset path
+        self.data_dir = "datasets/coco128"
+        self.train_ann = "instances_train2017.json"
+        self.val_ann = "instances_val2017.json"
+
+        self.num_classes = 71
+
+        self.max_epoch = 300
+        self.data_num_workers = 4
+        self.eval_interval = 1
diff --git a/multimodal/YOLOX/exps/example/yolox_voc/yolox_voc_s.py b/multimodal/YOLOX/exps/example/yolox_voc/yolox_voc_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..379ba9ac79adb4c0fa8677088f2c6eaafda38046
--- /dev/null
+++ b/multimodal/YOLOX/exps/example/yolox_voc/yolox_voc_s.py
@@ -0,0 +1,60 @@
+# encoding: utf-8
+import os
+
+from yolox.data import get_yolox_datadir
+from yolox.exp import Exp as MyExp
+
+
+class Exp(MyExp):
+    def __init__(self):
+        super(Exp, self).__init__()
+        self.num_classes = 20
+        self.depth = 0.33
+        self.width = 0.50
+        self.warmup_epochs = 1
+
+        # ---------- transform config ------------ #
+        self.mosaic_prob = 1.0
+        self.mixup_prob = 1.0
+        self.hsv_prob = 1.0
+        self.flip_prob = 0.5
+
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+    def get_dataset(self, cache: bool, cache_type: str = "ram"):
+        from yolox.data import VOCDetection, TrainTransform
+
+        return VOCDetection(
+            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
+            image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
+            img_size=self.input_size,
+            preproc=TrainTransform(
+                max_labels=50,
+                flip_prob=self.flip_prob,
+                hsv_prob=self.hsv_prob),
+            cache=cache,
+            cache_type=cache_type,
+        )
+
+    def get_eval_dataset(self, **kwargs):
+        from yolox.data import VOCDetection, ValTransform
+        legacy = kwargs.get("legacy", False)
+
+        return VOCDetection(
+            data_dir=os.path.join(get_yolox_datadir(), "VOCdevkit"),
+            image_sets=[('2007', 'test')],
+            img_size=self.test_size,
+            preproc=ValTransform(legacy=legacy),
+        )
+
+    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
+        from yolox.evaluators import VOCEvaluator
+
+        return VOCEvaluator(
+            dataloader=self.get_eval_loader(batch_size, is_distributed,
+                                            testdev=testdev, legacy=legacy),
+            img_size=self.test_size,
+            confthre=self.test_conf,
+            nmsthre=self.nmsthre,
+            num_classes=self.num_classes,
+        )
diff --git a/multimodal/YOLOX/hubconf.py b/multimodal/YOLOX/hubconf.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ff7f37fdd7efdd126e04f7ede3a9d066e74dde6
--- /dev/null
+++ b/multimodal/YOLOX/hubconf.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+"""
+Usage example:
+    import torch
+    model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_s")
+    model = torch.hub.load("Megvii-BaseDetection/YOLOX", "yolox_custom",
+                           exp_path="exp.py", ckpt_path="ckpt.pth")
+"""
+dependencies = ["torch"]
+
+from yolox.models import (  # isort:skip  # noqa: F401, E402
+    yolox_tiny,
+    yolox_nano,
+    yolox_s,
+    yolox_m,
+    yolox_l,
+    yolox_x,
+    yolov3,
+    yolox_custom
+)
diff --git a/multimodal/YOLOX/requirements.txt b/multimodal/YOLOX/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..58115c285cd888cf71ada5e9fbc33f25ff3798df
--- /dev/null
+++ b/multimodal/YOLOX/requirements.txt
@@ -0,0 +1,12 @@
+# TODO: Update with exact module version
+numpy
+torch>=1.7
+opencv_python
+loguru
+tqdm
+torchvision
+thop
+ninja
+tabulate
+psutil
+tensorboard
diff --git a/multimodal/YOLOX/setup.cfg b/multimodal/YOLOX/setup.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..9a277f407b2df1e25fe89bd635bb95c56e47c09a
--- /dev/null
+++ b/multimodal/YOLOX/setup.cfg
@@ -0,0 +1,18 @@
+[isort]
+line_length = 100
+multi_line_output = 3
+balanced_wrapping = True
+known_standard_library = setuptools
+known_third_party = tqdm,loguru,tabulate,psutil
+known_data_processing = cv2,numpy,scipy,PIL,matplotlib
+known_datasets = pycocotools
+known_deeplearning = torch,torchvision,caffe2,onnx,apex,timm,thop,torch2trt,tensorrt,openvino,onnxruntime
+known_myself = yolox
+sections = FUTURE,STDLIB,THIRDPARTY,data_processing,datasets,deeplearning,myself,FIRSTPARTY,LOCALFOLDER
+no_lines_before=STDLIB,THIRDPARTY,datasets
+default_section = FIRSTPARTY
+
+[flake8]
+max-line-length = 100
+max-complexity = 18
+exclude = __init__.py
diff --git a/multimodal/YOLOX/setup.py b/multimodal/YOLOX/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fec79764f284e49947e9b343b59fe3249fa04ed
--- /dev/null
+++ b/multimodal/YOLOX/setup.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python
+# Copyright (c) Megvii, Inc. and its affiliates. All Rights Reserved
+
+import re
+import setuptools
+import sys
+
+TORCH_AVAILABLE = True
+try:
+    import torch
+    from torch.utils import cpp_extension
+except ImportError:
+    TORCH_AVAILABLE = False
+    print("[WARNING] Unable to import torch, pre-compiling ops will be disabled.")
+
+
+def get_package_dir():
+    pkg_dir = {
+        "yolox.tools": "tools",
+        "yolox.exp.default": "exps/default",
+    }
+    return pkg_dir
+
+
+def get_install_requirements():
+    with open("requirements.txt", "r", encoding="utf-8") as f:
+        reqs = [x.strip() for x in f.read().splitlines()]
+    reqs = [x for x in reqs if not x.startswith("#")]
+    return reqs
+
+
+def get_yolox_version():
+    with open("yolox/__init__.py", "r") as f:
+        version = re.search(
+            r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
+            f.read(), re.MULTILINE
+        ).group(1)
+    return version
+
+
+def get_long_description():
+    with open("README.md", "r", encoding="utf-8") as f:
+        long_description = f.read()
+    return long_description
+
+
+def get_ext_modules():
+    ext_module = []
+    if sys.platform != "win32":  # pre-compile ops on linux
+        assert TORCH_AVAILABLE, "torch is required for pre-compiling ops, please install it first."
+        # if any other op is added, please also add it here
+        from yolox.layers import FastCOCOEvalOp
+        ext_module.append(FastCOCOEvalOp().build_op())
+    return ext_module
+
+
+def get_cmd_class():
+    cmdclass = {}
+    if TORCH_AVAILABLE:
+        cmdclass["build_ext"] = cpp_extension.BuildExtension
+    return cmdclass
+
+
+setuptools.setup(
+    name="yolox",
+    version=get_yolox_version(),
+    author="megvii basedet team",
+    url="https://github.com/Megvii-BaseDetection/YOLOX",
+    package_dir=get_package_dir(),
+    packages=setuptools.find_packages(exclude=("tests", "tools")) + list(get_package_dir().keys()),
+    python_requires=">=3.6",
+    install_requires=get_install_requirements(),
+    setup_requires=["wheel"],  # avoid building error when pip is not updated
+    long_description=get_long_description(),
+    long_description_content_type="text/markdown",
+    include_package_data=True,  # include files in MANIFEST.in
+    ext_modules=get_ext_modules(),
+    cmdclass=get_cmd_class(),
+    classifiers=[
+        "Programming Language :: Python :: 3", "Operating System :: OS Independent",
+        "License :: OSI Approved :: Apache Software License",
+    ],
+    project_urls={
+        "Documentation": "https://yolox.readthedocs.io",
+        "Source": "https://github.com/Megvii-BaseDetection/YOLOX",
+        "Tracker": "https://github.com/Megvii-BaseDetection/YOLOX/issues",
+    },
+)
diff --git a/multimodal/YOLOX/tests/__init__.py b/multimodal/YOLOX/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c53f601b3cf8436e1709a33363b218bc4f5ef512
--- /dev/null
+++ b/multimodal/YOLOX/tests/__init__.py
@@ -0,0 +1,2 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
diff --git a/multimodal/YOLOX/tests/utils/test_model_utils.py b/multimodal/YOLOX/tests/utils/test_model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..abfc3446f06974998c8ab25b5ded52e1327e2363
--- /dev/null
+++ b/multimodal/YOLOX/tests/utils/test_model_utils.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import unittest
+
+import torch
+from torch import nn
+
+from yolox.utils import adjust_status, freeze_module
+from yolox.exp import get_exp
+
+
+class TestModelUtils(unittest.TestCase):
+
+    def setUp(self):
+        self.model: nn.Module = get_exp(exp_name="yolox-s").get_model()
+
+    def test_model_state_adjust_status(self):
+        data = torch.ones(1, 10, 10, 10)
+        # use bn since bn changes state during train/val
+        model = nn.BatchNorm2d(10)
+        prev_state = model.state_dict()
+
+        modes = [False, True]
+        results = [True, False]
+
+        # test under train/eval mode
+        for mode, result in zip(modes, results):
+            with adjust_status(model, training=mode):
+                model(data)
+            model_state = model.state_dict()
+            self.assertTrue(len(model_state) == len(prev_state))
+            self.assertEqual(
+                result,
+                all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()])
+            )
+
+        # test recurrsive context case
+        prev_state = model.state_dict()
+        with adjust_status(model, training=False):
+            with adjust_status(model, training=False):
+                model(data)
+        model_state = model.state_dict()
+        self.assertTrue(len(model_state) == len(prev_state))
+        self.assertTrue(
+            all([torch.allclose(v, model_state[k]) for k, v in prev_state.items()])
+        )
+
+    def test_model_effect_adjust_status(self):
+        # test context effect
+        self.model.train()
+        with adjust_status(self.model, training=False):
+            for module in self.model.modules():
+                self.assertFalse(module.training)
+        # all training after exit
+        for module in self.model.modules():
+            self.assertTrue(module.training)
+
+        # only backbone set to eval
+        self.model.backbone.eval()
+        with adjust_status(self.model, training=False):
+            for module in self.model.modules():
+                self.assertFalse(module.training)
+
+        for name, module in self.model.named_modules():
+            if "backbone" in name:
+                self.assertFalse(module.training)
+            else:
+                self.assertTrue(module.training)
+
+    def test_freeze_module(self):
+        model = nn.Sequential(
+            nn.Conv2d(3, 10, 1),
+            nn.BatchNorm2d(10),
+            nn.ReLU(),
+        )
+        data = torch.rand(1, 3, 10, 10)
+        model.train()
+        assert isinstance(model[1], nn.BatchNorm2d)
+        before_states = model[1].state_dict()
+        freeze_module(model[1])
+        model(data)
+        after_states = model[1].state_dict()
+        self.assertTrue(
+            all([torch.allclose(v, after_states[k]) for k, v in before_states.items()])
+        )
+
+        # yolox test
+        self.model.train()
+        for module in self.model.modules():
+            self.assertTrue(module.training)
+
+        freeze_module(self.model, "backbone")
+        for module in self.model.backbone.modules():
+            self.assertFalse(module.training)
+        for p in self.model.backbone.parameters():
+            self.assertFalse(p.requires_grad)
+
+        for module in self.model.head.modules():
+            self.assertTrue(module.training)
+        for p in self.model.head.parameters():
+            self.assertTrue(p.requires_grad)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/multimodal/YOLOX/tools/__init__.py b/multimodal/YOLOX/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce9fae0677b11bdd96e516f4b0b8a3782daed1ec
--- /dev/null
+++ b/multimodal/YOLOX/tools/__init__.py
@@ -0,0 +1,3 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
diff --git a/multimodal/YOLOX/tools/demo.py b/multimodal/YOLOX/tools/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b16598d5f4f355a4884341bd1188052b9384018b
--- /dev/null
+++ b/multimodal/YOLOX/tools/demo.py
@@ -0,0 +1,320 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+import time
+from loguru import logger
+
+import cv2
+
+import torch
+
+from yolox.data.data_augment import ValTransform
+from yolox.data.datasets import COCO_CLASSES
+from yolox.exp import get_exp
+from yolox.utils import fuse_model, get_model_info, postprocess, vis
+
+IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX Demo!")
+    parser.add_argument(
+        "demo", default="image", help="demo type, eg. image, video and webcam"
+    )
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+
+    parser.add_argument(
+        "--path", default="./assets/dog.jpg", help="path to images or video"
+    )
+    parser.add_argument("--camid", type=int, default=0, help="webcam demo camera id")
+    parser.add_argument(
+        "--save_result",
+        action="store_true",
+        help="whether to save the inference result of image/video",
+    )
+
+    # exp file
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="please input your experiment description file",
+    )
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
+    parser.add_argument(
+        "--device",
+        default="cpu",
+        type=str,
+        help="device to run our model, can either be cpu or gpu",
+    )
+    parser.add_argument("--conf", default=0.3, type=float, help="test conf")
+    parser.add_argument("--nms", default=0.3, type=float, help="test nms threshold")
+    parser.add_argument("--tsize", default=None, type=int, help="test img size")
+    parser.add_argument(
+        "--fp16",
+        dest="fp16",
+        default=False,
+        action="store_true",
+        help="Adopting mix precision evaluating.",
+    )
+    parser.add_argument(
+        "--legacy",
+        dest="legacy",
+        default=False,
+        action="store_true",
+        help="To be compatible with older versions",
+    )
+    parser.add_argument(
+        "--fuse",
+        dest="fuse",
+        default=False,
+        action="store_true",
+        help="Fuse conv and bn for testing.",
+    )
+    parser.add_argument(
+        "--trt",
+        dest="trt",
+        default=False,
+        action="store_true",
+        help="Using TensorRT model for testing.",
+    )
+    return parser
+
+
+def get_image_list(path):
+    image_names = []
+    for maindir, subdir, file_name_list in os.walk(path):
+        for filename in file_name_list:
+            apath = os.path.join(maindir, filename)
+            ext = os.path.splitext(apath)[1]
+            if ext in IMAGE_EXT:
+                image_names.append(apath)
+    return image_names
+
+
+class Predictor(object):
+    def __init__(
+        self,
+        model,
+        exp,
+        cls_names=COCO_CLASSES,
+        trt_file=None,
+        decoder=None,
+        device="cpu",
+        fp16=False,
+        legacy=False,
+    ):
+        self.model = model
+        self.cls_names = cls_names
+        self.decoder = decoder
+        self.num_classes = exp.num_classes
+        self.confthre = exp.test_conf
+        self.nmsthre = exp.nmsthre
+        self.test_size = exp.test_size
+        self.device = device
+        self.fp16 = fp16
+        self.preproc = ValTransform(legacy=legacy)
+        if trt_file is not None:
+            from torch2trt import TRTModule
+
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(trt_file))
+
+            x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
+            self.model(x)
+            self.model = model_trt
+
+    def inference(self, img):
+        img_info = {"id": 0}
+        if isinstance(img, str):
+            img_info["file_name"] = os.path.basename(img)
+            img = cv2.imread(img)
+        else:
+            img_info["file_name"] = None
+
+        height, width = img.shape[:2]
+        img_info["height"] = height
+        img_info["width"] = width
+        img_info["raw_img"] = img
+
+        ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
+        img_info["ratio"] = ratio
+
+        img, _ = self.preproc(img, None, self.test_size)
+        img = torch.from_numpy(img).unsqueeze(0)
+        img = img.float()
+        if self.device == "gpu":
+            img = img.cuda()
+            if self.fp16:
+                img = img.half()  # to FP16
+
+        with torch.no_grad():
+            t0 = time.time()
+            outputs = self.model(img)
+            if self.decoder is not None:
+                outputs = self.decoder(outputs, dtype=outputs.type())
+            outputs = postprocess(
+                outputs, self.num_classes, self.confthre,
+                self.nmsthre, class_agnostic=True
+            )
+            logger.info("Infer time: {:.4f}s".format(time.time() - t0))
+        return outputs, img_info
+
+    def visual(self, output, img_info, cls_conf=0.35):
+        ratio = img_info["ratio"]
+        img = img_info["raw_img"]
+        if output is None:
+            return img
+        output = output.cpu()
+
+        bboxes = output[:, 0:4]
+
+        # preprocessing: resize
+        bboxes /= ratio
+
+        cls = output[:, 6]
+        scores = output[:, 4] * output[:, 5]
+
+        vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
+        return vis_res
+
+
+def image_demo(predictor, vis_folder, path, current_time, save_result):
+    if os.path.isdir(path):
+        files = get_image_list(path)
+    else:
+        files = [path]
+    files.sort()
+    for image_name in files:
+        outputs, img_info = predictor.inference(image_name)
+        result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
+        if save_result:
+            save_folder = os.path.join(
+                vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
+            )
+            os.makedirs(save_folder, exist_ok=True)
+            save_file_name = os.path.join(save_folder, os.path.basename(image_name))
+            logger.info("Saving detection result in {}".format(save_file_name))
+            cv2.imwrite(save_file_name, result_image)
+        ch = cv2.waitKey(0)
+        if ch == 27 or ch == ord("q") or ch == ord("Q"):
+            break
+
+
+def imageflow_demo(predictor, vis_folder, current_time, args):
+    cap = cv2.VideoCapture(args.path if args.demo == "video" else args.camid)
+    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
+    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    if args.save_result:
+        save_folder = os.path.join(
+            vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
+        )
+        os.makedirs(save_folder, exist_ok=True)
+        if args.demo == "video":
+            save_path = os.path.join(save_folder, os.path.basename(args.path))
+        else:
+            save_path = os.path.join(save_folder, "camera.mp4")
+        logger.info(f"video save_path is {save_path}")
+        vid_writer = cv2.VideoWriter(
+            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
+        )
+    while True:
+        ret_val, frame = cap.read()
+        if ret_val:
+            outputs, img_info = predictor.inference(frame)
+            result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
+            if args.save_result:
+                vid_writer.write(result_frame)
+            else:
+                cv2.namedWindow("yolox", cv2.WINDOW_NORMAL)
+                cv2.imshow("yolox", result_frame)
+            ch = cv2.waitKey(1)
+            if ch == 27 or ch == ord("q") or ch == ord("Q"):
+                break
+        else:
+            break
+
+
+def main(exp, args):
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    file_name = os.path.join(exp.output_dir, args.experiment_name)
+    os.makedirs(file_name, exist_ok=True)
+
+    vis_folder = None
+    if args.save_result:
+        vis_folder = os.path.join(file_name, "vis_res")
+        os.makedirs(vis_folder, exist_ok=True)
+
+    if args.trt:
+        args.device = "gpu"
+
+    logger.info("Args: {}".format(args))
+
+    if args.conf is not None:
+        exp.test_conf = args.conf
+    if args.nms is not None:
+        exp.nmsthre = args.nms
+    if args.tsize is not None:
+        exp.test_size = (args.tsize, args.tsize)
+
+    model = exp.get_model()
+    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
+
+    if args.device == "gpu":
+        model.cuda()
+        if args.fp16:
+            model.half()  # to FP16
+    model.eval()
+
+    if not args.trt:
+        if args.ckpt is None:
+            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+        else:
+            ckpt_file = args.ckpt
+        logger.info("loading checkpoint")
+        ckpt = torch.load(ckpt_file, map_location="cpu")
+        # load the model state dict
+        model.load_state_dict(ckpt["model"])
+        logger.info("loaded checkpoint done.")
+
+    if args.fuse:
+        logger.info("\tFusing model...")
+        model = fuse_model(model)
+
+    if args.trt:
+        assert not args.fuse, "TensorRT model is not support model fusing!"
+        trt_file = os.path.join(file_name, "model_trt.pth")
+        assert os.path.exists(
+            trt_file
+        ), "TensorRT model is not found!\n Run python3 tools/trt.py first!"
+        model.head.decode_in_inference = False
+        decoder = model.head.decode_outputs
+        logger.info("Using TensorRT to inference")
+    else:
+        trt_file = None
+        decoder = None
+
+    predictor = Predictor(
+        model, exp, COCO_CLASSES, trt_file, decoder,
+        args.device, args.fp16, args.legacy,
+    )
+    current_time = time.localtime()
+    if args.demo == "image":
+        image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
+    elif args.demo == "video" or args.demo == "webcam":
+        imageflow_demo(predictor, vis_folder, current_time, args)
+
+
+if __name__ == "__main__":
+    args = make_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+
+    main(exp, args)
diff --git a/multimodal/YOLOX/tools/eval.py b/multimodal/YOLOX/tools/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..83ad76be884e195e01a14fb376371b5531af14c5
--- /dev/null
+++ b/multimodal/YOLOX/tools/eval.py
@@ -0,0 +1,220 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+import random
+import warnings
+from loguru import logger
+
+import torch
+import torch.backends.cudnn as cudnn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from yolox.core import launch
+from yolox.exp import get_exp
+from yolox.utils import (
+    configure_module,
+    configure_nccl,
+    fuse_model,
+    get_local_rank,
+    get_model_info,
+    setup_logger
+)
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX Eval")
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+
+    # distributed
+    parser.add_argument(
+        "--dist-backend", default="nccl", type=str, help="distributed backend"
+    )
+    parser.add_argument(
+        "--dist-url",
+        default=None,
+        type=str,
+        help="url used to set up distributed training",
+    )
+    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
+    parser.add_argument(
+        "-d", "--devices", default=None, type=int, help="device for training"
+    )
+    parser.add_argument(
+        "--num_machines", default=1, type=int, help="num of node for training"
+    )
+    parser.add_argument(
+        "--machine_rank", default=0, type=int, help="node rank for multi-node training"
+    )
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="please input your experiment description file",
+    )
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
+    parser.add_argument("--conf", default=None, type=float, help="test conf")
+    parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
+    parser.add_argument("--tsize", default=None, type=int, help="test img size")
+    parser.add_argument("--seed", default=None, type=int, help="eval seed")
+    parser.add_argument(
+        "--fp16",
+        dest="fp16",
+        default=False,
+        action="store_true",
+        help="Adopting mix precision evaluating.",
+    )
+    parser.add_argument(
+        "--fuse",
+        dest="fuse",
+        default=False,
+        action="store_true",
+        help="Fuse conv and bn for testing.",
+    )
+    parser.add_argument(
+        "--trt",
+        dest="trt",
+        default=False,
+        action="store_true",
+        help="Using TensorRT model for testing.",
+    )
+    parser.add_argument(
+        "--legacy",
+        dest="legacy",
+        default=False,
+        action="store_true",
+        help="To be compatible with older versions",
+    )
+    parser.add_argument(
+        "--test",
+        dest="test",
+        default=False,
+        action="store_true",
+        help="Evaluating on test-dev set.",
+    )
+    parser.add_argument(
+        "--speed",
+        dest="speed",
+        default=False,
+        action="store_true",
+        help="speed test only.",
+    )
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+@logger.catch
+def main(exp, args, num_gpu):
+    if args.seed is not None:
+        random.seed(args.seed)
+        torch.manual_seed(args.seed)
+        cudnn.deterministic = True
+        warnings.warn(
+            "You have chosen to seed testing. This will turn on the CUDNN deterministic setting, "
+        )
+
+    is_distributed = num_gpu > 1
+
+    # set environment variables for distributed training
+    configure_nccl()
+    cudnn.benchmark = True
+
+    rank = get_local_rank()
+
+    file_name = os.path.join(exp.output_dir, args.experiment_name)
+
+    if rank == 0:
+        os.makedirs(file_name, exist_ok=True)
+
+    setup_logger(file_name, distributed_rank=rank, filename="val_log.txt", mode="a")
+    logger.info("Args: {}".format(args))
+
+    if args.conf is not None:
+        exp.test_conf = args.conf
+    if args.nms is not None:
+        exp.nmsthre = args.nms
+    if args.tsize is not None:
+        exp.test_size = (args.tsize, args.tsize)
+
+    model = exp.get_model()
+    logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
+    logger.info("Model Structure:\n{}".format(str(model)))
+
+    evaluator = exp.get_evaluator(args.batch_size, is_distributed, args.test, args.legacy)
+    evaluator.per_class_AP = True
+    evaluator.per_class_AR = True
+
+    torch.cuda.set_device(rank)
+    model.cuda(rank)
+    model.eval()
+
+    if not args.speed and not args.trt:
+        if args.ckpt is None:
+            ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+        else:
+            ckpt_file = args.ckpt
+        logger.info("loading checkpoint from {}".format(ckpt_file))
+        loc = "cuda:{}".format(rank)
+        ckpt = torch.load(ckpt_file, map_location=loc)
+        model.load_state_dict(ckpt["model"])
+        logger.info("loaded checkpoint done.")
+
+    if is_distributed:
+        model = DDP(model, device_ids=[rank])
+
+    if args.fuse:
+        logger.info("\tFusing model...")
+        model = fuse_model(model)
+
+    if args.trt:
+        assert (
+            not args.fuse and not is_distributed and args.batch_size == 1
+        ), "TensorRT model is not support model fusing and distributed inferencing!"
+        trt_file = os.path.join(file_name, "model_trt.pth")
+        assert os.path.exists(
+            trt_file
+        ), "TensorRT model is not found!\n Run tools/trt.py first!"
+        model.head.decode_in_inference = False
+        decoder = model.head.decode_outputs
+    else:
+        trt_file = None
+        decoder = None
+
+    # start evaluate
+    *_, summary = evaluator.evaluate(
+        model, is_distributed, args.fp16, trt_file, decoder, exp.test_size
+    )
+    logger.info("\n" + summary)
+
+
+if __name__ == "__main__":
+    configure_module()
+    args = make_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
+    assert num_gpu <= torch.cuda.device_count()
+
+    dist_url = "auto" if args.dist_url is None else args.dist_url
+    launch(
+        main,
+        num_gpu,
+        args.num_machines,
+        args.machine_rank,
+        backend=args.dist_backend,
+        dist_url=dist_url,
+        args=(exp, args, num_gpu),
+    )
diff --git a/multimodal/YOLOX/tools/export_onnx.py b/multimodal/YOLOX/tools/export_onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..8703166a4ee487d2d4b713b42c6f8c55879281db
--- /dev/null
+++ b/multimodal/YOLOX/tools/export_onnx.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+from loguru import logger
+
+import torch
+from torch import nn
+
+from yolox.exp import get_exp
+from yolox.models.network_blocks import SiLU
+from yolox.utils import replace_module
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX onnx deploy")
+    parser.add_argument(
+        "--output-name", type=str, default="yolox.onnx", help="output name of models"
+    )
+    parser.add_argument(
+        "--input", default="images", type=str, help="input node name of onnx model"
+    )
+    parser.add_argument(
+        "--output", default="output", type=str, help="output node name of onnx model"
+    )
+    parser.add_argument(
+        "-o", "--opset", default=11, type=int, help="onnx opset version"
+    )
+    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
+    parser.add_argument(
+        "--dynamic", action="store_true", help="whether the input shape should be dynamic or not"
+    )
+    parser.add_argument("--no-onnxsim", action="store_true", help="use onnxsim or not")
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="experiment description file",
+    )
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    parser.add_argument(
+        "--decode_in_inference",
+        action="store_true",
+        help="decode in inference or not"
+    )
+
+    return parser
+
+
+@logger.catch
+def main():
+    args = make_parser().parse_args()
+    logger.info("args value: {}".format(args))
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    model = exp.get_model()
+    if args.ckpt is None:
+        file_name = os.path.join(exp.output_dir, args.experiment_name)
+        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+    else:
+        ckpt_file = args.ckpt
+
+    # load the model state dict
+    ckpt = torch.load(ckpt_file, map_location="cpu")
+
+    model.eval()
+    if "model" in ckpt:
+        ckpt = ckpt["model"]
+    model.load_state_dict(ckpt)
+    model = replace_module(model, nn.SiLU, SiLU)
+    model.head.decode_in_inference = args.decode_in_inference
+
+    logger.info("loading checkpoint done.")
+    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])
+
+    torch.onnx._export(
+        model,
+        dummy_input,
+        args.output_name,
+        input_names=[args.input],
+        output_names=[args.output],
+        dynamic_axes={args.input: {0: 'batch'},
+                      args.output: {0: 'batch'}} if args.dynamic else None,
+        opset_version=args.opset,
+    )
+    logger.info("generated onnx model named {}".format(args.output_name))
+
+    if not args.no_onnxsim:
+        import onnx
+        from onnxsim import simplify
+
+        # use onnx-simplifier to reduce reduent model.
+        onnx_model = onnx.load(args.output_name)
+        model_simp, check = simplify(onnx_model)
+        assert check, "Simplified ONNX model could not be validated"
+        onnx.save(model_simp, args.output_name)
+        logger.info("generated simplified onnx model named {}".format(args.output_name))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/YOLOX/tools/export_torchscript.py b/multimodal/YOLOX/tools/export_torchscript.py
new file mode 100644
index 0000000000000000000000000000000000000000..16a563bc56fe7c61475aec31ab5f2b604398cda9
--- /dev/null
+++ b/multimodal/YOLOX/tools/export_torchscript.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+from loguru import logger
+
+import torch
+
+from yolox.exp import get_exp
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX torchscript deploy")
+    parser.add_argument(
+        "--output-name", type=str, default="yolox.torchscript.pt", help="output name of models"
+    )
+    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="experiment description file",
+    )
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
+    parser.add_argument(
+        "--decode_in_inference",
+        action="store_true",
+        help="decode in inference or not"
+    )
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+
+    return parser
+
+
+@logger.catch
+def main():
+    args = make_parser().parse_args()
+    logger.info("args value: {}".format(args))
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    model = exp.get_model()
+    if args.ckpt is None:
+        file_name = os.path.join(exp.output_dir, args.experiment_name)
+        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+    else:
+        ckpt_file = args.ckpt
+
+    # load the model state dict
+    ckpt = torch.load(ckpt_file, map_location="cpu")
+
+    model.eval()
+    if "model" in ckpt:
+        ckpt = ckpt["model"]
+    model.load_state_dict(ckpt)
+    model.head.decode_in_inference = args.decode_in_inference
+
+    logger.info("loading checkpoint done.")
+    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])
+
+    mod = torch.jit.trace(model, dummy_input)
+    mod.save(args.output_name)
+    logger.info("generated torchscript model named {}".format(args.output_name))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/YOLOX/tools/train.py b/multimodal/YOLOX/tools/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d57f420aebc81b3de24584ad19dacf5dc9ae3279
--- /dev/null
+++ b/multimodal/YOLOX/tools/train.py
@@ -0,0 +1,146 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import random
+import warnings
+from loguru import logger
+
+import torch
+import torch.backends.cudnn as cudnn
+
+from yolox.core import launch
+from yolox.exp import Exp, check_exp_value, get_exp
+from yolox.utils import configure_module, configure_nccl, configure_omp, get_num_devices
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX train parser")
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+
+    # distributed
+    parser.add_argument(
+        "--dist-backend", default="nccl", type=str, help="distributed backend"
+    )
+    parser.add_argument(
+        "--dist-url",
+        default=None,
+        type=str,
+        help="url used to set up distributed training",
+    )
+    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
+    parser.add_argument(
+        "-d", "--devices", default=None, type=int, help="device for training"
+    )
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="plz input your experiment description file",
+    )
+    parser.add_argument(
+        "--resume", default=False, action="store_true", help="resume training"
+    )
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="checkpoint file")
+    parser.add_argument(
+        "-e",
+        "--start_epoch",
+        default=None,
+        type=int,
+        help="resume training start epoch",
+    )
+    parser.add_argument(
+        "--num_machines", default=1, type=int, help="num of node for training"
+    )
+    parser.add_argument(
+        "--machine_rank", default=0, type=int, help="node rank for multi-node training"
+    )
+    parser.add_argument(
+        "--fp16",
+        dest="fp16",
+        default=False,
+        action="store_true",
+        help="Adopting mix precision training.",
+    )
+    parser.add_argument(
+        "--cache",
+        type=str,
+        nargs="?",
+        const="ram",
+        help="Caching imgs to ram/disk for fast training.",
+    )
+    parser.add_argument(
+        "-o",
+        "--occupy",
+        dest="occupy",
+        default=False,
+        action="store_true",
+        help="occupy GPU memory first for training.",
+    )
+    parser.add_argument(
+        "-l",
+        "--logger",
+        type=str,
+        help="Logger to be used for metrics. \
+        Implemented loggers include `tensorboard` and `wandb`.",
+        default="tensorboard"
+    )
+    parser.add_argument(
+        "opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+@logger.catch
+def main(exp: Exp, args):
+    if exp.seed is not None:
+        random.seed(exp.seed)
+        torch.manual_seed(exp.seed)
+        cudnn.deterministic = True
+        warnings.warn(
+            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
+            "which can slow down your training considerably! You may see unexpected behavior "
+            "when restarting from checkpoints."
+        )
+
+    # set environment variables for distributed training
+    configure_nccl()
+    configure_omp()
+    cudnn.benchmark = True
+
+    trainer = exp.get_trainer(args)
+    trainer.train()
+
+
+if __name__ == "__main__":
+    configure_module()
+    args = make_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+    check_exp_value(exp)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    num_gpu = get_num_devices() if args.devices is None else args.devices
+    assert num_gpu <= get_num_devices()
+
+    if args.cache is not None:
+        exp.dataset = exp.get_dataset(cache=True, cache_type=args.cache)
+
+    dist_url = "auto" if args.dist_url is None else args.dist_url
+    launch(
+        main,
+        num_gpu,
+        args.num_machines,
+        args.machine_rank,
+        backend=args.dist_backend,
+        dist_url=dist_url,
+        args=(exp, args),
+    )
diff --git a/multimodal/YOLOX/tools/trt.py b/multimodal/YOLOX/tools/trt.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2f6cee5c66038c126b23c94f1d935da04019b35
--- /dev/null
+++ b/multimodal/YOLOX/tools/trt.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import argparse
+import os
+import shutil
+from loguru import logger
+
+import tensorrt as trt
+import torch
+from torch2trt import torch2trt
+
+from yolox.exp import get_exp
+
+
+def make_parser():
+    parser = argparse.ArgumentParser("YOLOX ncnn deploy")
+    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
+    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
+
+    parser.add_argument(
+        "-f",
+        "--exp_file",
+        default=None,
+        type=str,
+        help="please input your experiment description file",
+    )
+    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
+    parser.add_argument(
+        "-w", '--workspace', type=int, default=32, help='max workspace size in detect'
+    )
+    parser.add_argument("-b", '--batch', type=int, default=1, help='max batch size in detect')
+    return parser
+
+
+@logger.catch
+@torch.no_grad()
+def main():
+    args = make_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    model = exp.get_model()
+    file_name = os.path.join(exp.output_dir, args.experiment_name)
+    os.makedirs(file_name, exist_ok=True)
+    if args.ckpt is None:
+        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
+    else:
+        ckpt_file = args.ckpt
+
+    ckpt = torch.load(ckpt_file, map_location="cpu")
+    # load the model state dict
+
+    model.load_state_dict(ckpt["model"])
+    logger.info("loaded checkpoint done.")
+    model.eval()
+    model.cuda()
+    model.head.decode_in_inference = False
+    x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
+    model_trt = torch2trt(
+        model,
+        [x],
+        fp16_mode=True,
+        log_level=trt.Logger.INFO,
+        max_workspace_size=(1 << args.workspace),
+        max_batch_size=args.batch,
+    )
+    torch.save(model_trt.state_dict(), os.path.join(file_name, "model_trt.pth"))
+    logger.info("Converted TensorRT model done.")
+    engine_file = os.path.join(file_name, "model_trt.engine")
+    engine_file_demo = os.path.join("demo", "TensorRT", "cpp", "model_trt.engine")
+    with open(engine_file, "wb") as f:
+        f.write(model_trt.engine.serialize())
+
+    shutil.copyfile(engine_file, engine_file_demo)
+
+    logger.info("Converted TensorRT model engine file is saved for C++ inference.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/YOLOX/tools/visualize_assign.py b/multimodal/YOLOX/tools/visualize_assign.py
new file mode 100644
index 0000000000000000000000000000000000000000..e75a5586b7c878225327fc9046e2d4f21f182bb8
--- /dev/null
+++ b/multimodal/YOLOX/tools/visualize_assign.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+import sys
+import random
+import time
+import warnings
+from loguru import logger
+
+import torch
+import torch.backends.cudnn as cudnn
+
+from yolox.exp import Exp, get_exp
+from yolox.core import Trainer
+from yolox.utils import configure_module, configure_omp
+from yolox.tools.train import make_parser
+
+
+class AssignVisualizer(Trainer):
+
+    def __init__(self, exp: Exp, args):
+        super().__init__(exp, args)
+        self.batch_cnt = 0
+        self.vis_dir = os.path.join(self.file_name, "vis")
+        os.makedirs(self.vis_dir, exist_ok=True)
+
+    def train_one_iter(self):
+        iter_start_time = time.time()
+
+        inps, targets = self.prefetcher.next()
+        inps = inps.to(self.data_type)
+        targets = targets.to(self.data_type)
+        targets.requires_grad = False
+        inps, targets = self.exp.preprocess(inps, targets, self.input_size)
+        data_end_time = time.time()
+
+        with torch.cuda.amp.autocast(enabled=self.amp_training):
+            path_prefix = os.path.join(self.vis_dir, f"assign_vis_{self.batch_cnt}_")
+            self.model.visualize(inps, targets, path_prefix)
+
+        if self.use_model_ema:
+            self.ema_model.update(self.model)
+
+        iter_end_time = time.time()
+        self.meter.update(
+            iter_time=iter_end_time - iter_start_time,
+            data_time=data_end_time - iter_start_time,
+        )
+        self.batch_cnt += 1
+        if self.batch_cnt >= self.args.max_batch:
+            sys.exit(0)
+
+    def after_train(self):
+        logger.info("Finish visualize assignment, exit...")
+
+
+def assign_vis_parser():
+    parser = make_parser()
+    parser.add_argument("--max-batch", type=int, default=1, help="max batch of images to visualize")
+    return parser
+
+
+@logger.catch
+def main(exp: Exp, args):
+    if exp.seed is not None:
+        random.seed(exp.seed)
+        torch.manual_seed(exp.seed)
+        cudnn.deterministic = True
+        warnings.warn(
+            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
+            "which can slow down your training considerably! You may see unexpected behavior "
+            "when restarting from checkpoints."
+        )
+
+    # set environment variables for distributed training
+    configure_omp()
+    cudnn.benchmark = True
+
+    visualizer = AssignVisualizer(exp, args)
+    visualizer.train()
+
+
+if __name__ == "__main__":
+    configure_module()
+    args = assign_vis_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    main(exp, args)
diff --git a/multimodal/YOLOX/yolox/__init__.py b/multimodal/YOLOX/yolox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c2c297ccde99381f96c6f36d7c2854a7418c161
--- /dev/null
+++ b/multimodal/YOLOX/yolox/__init__.py
@@ -0,0 +1,4 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+__version__ = "0.3.0"
diff --git a/multimodal/YOLOX/yolox/core/__init__.py b/multimodal/YOLOX/yolox/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2379c704ec6320066cbb45a6b8dacca548662a0
--- /dev/null
+++ b/multimodal/YOLOX/yolox/core/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .launch import launch
+from .trainer import Trainer
diff --git a/multimodal/YOLOX/yolox/core/launch.py b/multimodal/YOLOX/yolox/core/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f8eec61e379f7a4179536742c16609d240b55d6
--- /dev/null
+++ b/multimodal/YOLOX/yolox/core/launch.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Code are based on
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import sys
+from datetime import timedelta
+from loguru import logger
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import yolox.utils.dist as comm
+
+__all__ = ["launch"]
+
+
+DEFAULT_TIMEOUT = timedelta(minutes=30)
+
+
+def _find_free_port():
+    """
+    Find an available port of current machine / node.
+    """
+    import socket
+
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    # Binding to port 0 will cause the OS to find an available port for us
+    sock.bind(("", 0))
+    port = sock.getsockname()[1]
+    sock.close()
+    # NOTE: there is still a chance the port could be taken by other processes.
+    return port
+
+
+def launch(
+    main_func,
+    num_gpus_per_machine,
+    num_machines=1,
+    machine_rank=0,
+    backend="nccl",
+    dist_url=None,
+    args=(),
+    timeout=DEFAULT_TIMEOUT,
+):
+    """
+    Args:
+        main_func: a function that will be called by `main_func(*args)`
+        num_machines (int): the total number of machines
+        machine_rank (int): the rank of this machine (one per machine)
+        dist_url (str): url to connect to for distributed training, including protocol
+                       e.g. "tcp://127.0.0.1:8686".
+                       Can be set to auto to automatically select a free port on localhost
+        args (tuple): arguments passed to main_func
+    """
+    world_size = num_machines * num_gpus_per_machine
+    if world_size > 1:
+        # https://github.com/pytorch/pytorch/pull/14391
+        # TODO prctl in spawned processes
+
+        if dist_url == "auto":
+            assert (
+                num_machines == 1
+            ), "dist_url=auto cannot work with distributed training."
+            port = _find_free_port()
+            dist_url = f"tcp://127.0.0.1:{port}"
+
+        start_method = "spawn"
+        cache = vars(args[1]).get("cache", False)
+
+        # To use numpy memmap for caching image into RAM, we have to use fork method
+        if cache:
+            assert sys.platform != "win32", (
+                "As Windows platform doesn't support fork method, "
+                "do not add --cache in your training command."
+            )
+            start_method = "fork"
+
+        mp.start_processes(
+            _distributed_worker,
+            nprocs=num_gpus_per_machine,
+            args=(
+                main_func,
+                world_size,
+                num_gpus_per_machine,
+                machine_rank,
+                backend,
+                dist_url,
+                args,
+            ),
+            daemon=False,
+            start_method=start_method,
+        )
+    else:
+        main_func(*args)
+
+
+def _distributed_worker(
+    local_rank,
+    main_func,
+    world_size,
+    num_gpus_per_machine,
+    machine_rank,
+    backend,
+    dist_url,
+    args,
+    timeout=DEFAULT_TIMEOUT,
+):
+    assert (
+        torch.cuda.is_available()
+    ), "cuda is not available. Please check your installation."
+    global_rank = machine_rank * num_gpus_per_machine + local_rank
+    logger.info("Rank {} initialization finished.".format(global_rank))
+    try:
+        dist.init_process_group(
+            backend=backend,
+            init_method=dist_url,
+            world_size=world_size,
+            rank=global_rank,
+            timeout=timeout,
+        )
+    except Exception:
+        logger.error("Process group URL: {}".format(dist_url))
+        raise
+
+    # Setup the local process group (which contains ranks within the same machine)
+    assert comm._LOCAL_PROCESS_GROUP is None
+    num_machines = world_size // num_gpus_per_machine
+    for i in range(num_machines):
+        ranks_on_i = list(
+            range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine)
+        )
+        pg = dist.new_group(ranks_on_i)
+        if i == machine_rank:
+            comm._LOCAL_PROCESS_GROUP = pg
+
+    # synchronize is needed here to prevent a possible timeout after calling init_process_group
+    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
+    comm.synchronize()
+
+    assert num_gpus_per_machine <= torch.cuda.device_count()
+    torch.cuda.set_device(local_rank)
+
+    main_func(*args)
diff --git a/multimodal/YOLOX/yolox/core/trainer.py b/multimodal/YOLOX/yolox/core/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a76442680b64be32af7e21d90e786eac7059c22d
--- /dev/null
+++ b/multimodal/YOLOX/yolox/core/trainer.py
@@ -0,0 +1,390 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import datetime
+import os
+import time
+from loguru import logger
+
+import torch
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+
+from yolox.data import DataPrefetcher
+from yolox.exp import Exp
+from yolox.utils import (
+    MeterBuffer,
+    ModelEMA,
+    WandbLogger,
+    adjust_status,
+    all_reduce_norm,
+    get_local_rank,
+    get_model_info,
+    get_rank,
+    get_world_size,
+    gpu_mem_usage,
+    is_parallel,
+    load_ckpt,
+    mem_usage,
+    occupy_mem,
+    save_checkpoint,
+    setup_logger,
+    synchronize
+)
+
+
+class Trainer:
+    def __init__(self, exp: Exp, args):
+        # init function only defines some basic attr, other attrs like model, optimizer are built in
+        # before_train methods.
+        self.exp = exp
+        self.args = args
+
+        # training related attr
+        self.max_epoch = exp.max_epoch
+        self.amp_training = args.fp16
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+        self.is_distributed = get_world_size() > 1
+        self.rank = get_rank()
+        self.local_rank = get_local_rank()
+        self.device = "cuda:{}".format(self.local_rank)
+        self.use_model_ema = exp.ema
+        self.save_history_ckpt = exp.save_history_ckpt
+
+        # data/dataloader related attr
+        self.data_type = torch.float16 if args.fp16 else torch.float32
+        self.input_size = exp.input_size
+        self.best_ap = 0
+
+        # metric record
+        self.meter = MeterBuffer(window_size=exp.print_interval)
+        self.file_name = os.path.join(exp.output_dir, args.experiment_name)
+
+        if self.rank == 0:
+            os.makedirs(self.file_name, exist_ok=True)
+
+        setup_logger(
+            self.file_name,
+            distributed_rank=self.rank,
+            filename="train_log.txt",
+            mode="a",
+        )
+
+    def train(self):
+        self.before_train()
+        try:
+            self.train_in_epoch()
+        except Exception:
+            raise
+        finally:
+            self.after_train()
+
+    def train_in_epoch(self):
+        for self.epoch in range(self.start_epoch, self.max_epoch):
+            self.before_epoch()
+            self.train_in_iter()
+            self.after_epoch()
+
+    def train_in_iter(self):
+        for self.iter in range(self.max_iter):
+            self.before_iter()
+            self.train_one_iter()
+            self.after_iter()
+
+    def train_one_iter(self):
+        iter_start_time = time.time()
+
+        inps, targets = self.prefetcher.next()
+        inps = inps.to(self.data_type)
+        targets = targets.to(self.data_type)
+        targets.requires_grad = False
+        inps, targets = self.exp.preprocess(inps, targets, self.input_size)
+        data_end_time = time.time()
+
+        with torch.cuda.amp.autocast(enabled=self.amp_training):
+            outputs = self.model(inps, targets)
+
+        loss = outputs["total_loss"]
+
+        self.optimizer.zero_grad()
+        self.scaler.scale(loss).backward()
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+
+        if self.use_model_ema:
+            self.ema_model.update(self.model)
+
+        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
+        for param_group in self.optimizer.param_groups:
+            param_group["lr"] = lr
+
+        iter_end_time = time.time()
+        self.meter.update(
+            iter_time=iter_end_time - iter_start_time,
+            data_time=data_end_time - iter_start_time,
+            lr=lr,
+            **outputs,
+        )
+
+    def before_train(self):
+        logger.info("args: {}".format(self.args))
+        logger.info("exp value:\n{}".format(self.exp))
+
+        # model related init
+        torch.cuda.set_device(self.local_rank)
+        model = self.exp.get_model()
+        logger.info(
+            "Model Summary: {}".format(get_model_info(model, self.exp.test_size))
+        )
+        model.to(self.device)
+
+        # solver related init
+        self.optimizer = self.exp.get_optimizer(self.args.batch_size)
+
+        # value of epoch will be set in `resume_train`
+        model = self.resume_train(model)
+
+        # data related init
+        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
+        self.train_loader = self.exp.get_data_loader(
+            batch_size=self.args.batch_size,
+            is_distributed=self.is_distributed,
+            no_aug=self.no_aug,
+            cache_img=self.args.cache,
+        )
+        logger.info("init prefetcher, this might take one minute or less...")
+        self.prefetcher = DataPrefetcher(self.train_loader)
+        # max_iter means iters per epoch
+        self.max_iter = len(self.train_loader)
+
+        self.lr_scheduler = self.exp.get_lr_scheduler(
+            self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
+        )
+        if self.args.occupy:
+            occupy_mem(self.local_rank)
+
+        if self.is_distributed:
+            model = DDP(model, device_ids=[self.local_rank], broadcast_buffers=False)
+
+        if self.use_model_ema:
+            self.ema_model = ModelEMA(model, 0.9998)
+            self.ema_model.updates = self.max_iter * self.start_epoch
+
+        self.model = model
+
+        self.evaluator = self.exp.get_evaluator(
+            batch_size=self.args.batch_size, is_distributed=self.is_distributed
+        )
+        # Tensorboard and Wandb loggers
+        if self.rank == 0:
+            if self.args.logger == "tensorboard":
+                self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
+            elif self.args.logger == "wandb":
+                self.wandb_logger = WandbLogger.initialize_wandb_logger(
+                    self.args,
+                    self.exp,
+                    self.evaluator.dataloader.dataset
+                )
+            else:
+                raise ValueError("logger must be either 'tensorboard' or 'wandb'")
+
+        logger.info("Training start...")
+        logger.info("\n{}".format(model))
+
+    def after_train(self):
+        logger.info(
+            "Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100)
+        )
+        if self.rank == 0:
+            if self.args.logger == "wandb":
+                self.wandb_logger.finish()
+
+    def before_epoch(self):
+        logger.info("---> start train epoch{}".format(self.epoch + 1))
+
+        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
+            logger.info("--->No mosaic aug now!")
+            self.train_loader.close_mosaic()
+            logger.info("--->Add additional L1 loss now!")
+            if self.is_distributed:
+                self.model.module.head.use_l1 = True
+            else:
+                self.model.head.use_l1 = True
+            self.exp.eval_interval = 1
+            if not self.no_aug:
+                self.save_ckpt(ckpt_name="last_mosaic_epoch")
+
+    def after_epoch(self):
+        self.save_ckpt(ckpt_name="latest")
+
+        if (self.epoch + 1) % self.exp.eval_interval == 0:
+            all_reduce_norm(self.model)
+            self.evaluate_and_save_model()
+
+    def before_iter(self):
+        pass
+
+    def after_iter(self):
+        """
+        `after_iter` contains two parts of logic:
+            * log information
+            * reset setting of resize
+        """
+        # log needed information
+        if (self.iter + 1) % self.exp.print_interval == 0:
+            # TODO check ETA logic
+            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
+            eta_seconds = self.meter["iter_time"].global_avg * left_iters
+            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))
+
+            progress_str = "epoch: {}/{}, iter: {}/{}".format(
+                self.epoch + 1, self.max_epoch, self.iter + 1, self.max_iter
+            )
+            loss_meter = self.meter.get_filtered_meter("loss")
+            loss_str = ", ".join(
+                ["{}: {:.1f}".format(k, v.latest) for k, v in loss_meter.items()]
+            )
+
+            time_meter = self.meter.get_filtered_meter("time")
+            time_str = ", ".join(
+                ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
+            )
+
+            mem_str = "gpu mem: {:.0f}Mb, mem: {:.1f}Gb".format(gpu_mem_usage(), mem_usage())
+
+            logger.info(
+                "{}, {}, {}, {}, lr: {:.3e}".format(
+                    progress_str,
+                    mem_str,
+                    time_str,
+                    loss_str,
+                    self.meter["lr"].latest,
+                )
+                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
+            )
+
+            if self.rank == 0:
+                if self.args.logger == "tensorboard":
+                    self.tblogger.add_scalar(
+                        "train/lr", self.meter["lr"].latest, self.progress_in_iter)
+                    for k, v in loss_meter.items():
+                        self.tblogger.add_scalar(
+                            f"train/{k}", v.latest, self.progress_in_iter)
+                if self.args.logger == "wandb":
+                    metrics = {"train/" + k: v.latest for k, v in loss_meter.items()}
+                    metrics.update({
+                        "train/lr": self.meter["lr"].latest
+                    })
+                    self.wandb_logger.log_metrics(metrics, step=self.progress_in_iter)
+
+            self.meter.clear_meters()
+
+        # random resizing
+        if (self.progress_in_iter + 1) % 10 == 0:
+            self.input_size = self.exp.random_resize(
+                self.train_loader, self.epoch, self.rank, self.is_distributed
+            )
+
+    @property
+    def progress_in_iter(self):
+        return self.epoch * self.max_iter + self.iter
+
+    def resume_train(self, model):
+        if self.args.resume:
+            logger.info("resume training")
+            if self.args.ckpt is None:
+                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth")
+            else:
+                ckpt_file = self.args.ckpt
+
+            ckpt = torch.load(ckpt_file, map_location=self.device)
+            # resume the model/optimizer state dict
+            model.load_state_dict(ckpt["model"])
+            self.optimizer.load_state_dict(ckpt["optimizer"])
+            self.best_ap = ckpt.pop("best_ap", 0)
+            # resume the training states variables
+            start_epoch = (
+                self.args.start_epoch - 1
+                if self.args.start_epoch is not None
+                else ckpt["start_epoch"]
+            )
+            self.start_epoch = start_epoch
+            logger.info(
+                "loaded checkpoint '{}' (epoch {})".format(
+                    self.args.resume, self.start_epoch
+                )
+            )  # noqa
+        else:
+            if self.args.ckpt is not None:
+                logger.info("loading checkpoint for fine tuning")
+                ckpt_file = self.args.ckpt
+                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
+                model = load_ckpt(model, ckpt)
+            self.start_epoch = 0
+
+        return model
+
+    def evaluate_and_save_model(self):
+        if self.use_model_ema:
+            evalmodel = self.ema_model.ema
+        else:
+            evalmodel = self.model
+            if is_parallel(evalmodel):
+                evalmodel = evalmodel.module
+
+        with adjust_status(evalmodel, training=False):
+            (ap50_95, ap50, summary), predictions = self.exp.eval(
+                evalmodel, self.evaluator, self.is_distributed, return_outputs=True
+            )
+
+        update_best_ckpt = ap50_95 > self.best_ap
+        self.best_ap = max(self.best_ap, ap50_95)
+
+        if self.rank == 0:
+            if self.args.logger == "tensorboard":
+                self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
+                self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
+            if self.args.logger == "wandb":
+                self.wandb_logger.log_metrics({
+                    "val/COCOAP50": ap50,
+                    "val/COCOAP50_95": ap50_95,
+                    "train/epoch": self.epoch + 1,
+                })
+                self.wandb_logger.log_images(predictions)
+            logger.info("\n" + summary)
+        synchronize()
+
+        self.save_ckpt("last_epoch", update_best_ckpt, ap=ap50_95)
+        if self.save_history_ckpt:
+            self.save_ckpt(f"epoch_{self.epoch + 1}", ap=ap50_95)
+
+    def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
+        if self.rank == 0:
+            save_model = self.ema_model.ema if self.use_model_ema else self.model
+            logger.info("Save weights to {}".format(self.file_name))
+            ckpt_state = {
+                "start_epoch": self.epoch + 1,
+                "model": save_model.state_dict(),
+                "optimizer": self.optimizer.state_dict(),
+                "best_ap": self.best_ap,
+                "curr_ap": ap,
+            }
+            save_checkpoint(
+                ckpt_state,
+                update_best_ckpt,
+                self.file_name,
+                ckpt_name,
+            )
+
+            if self.args.logger == "wandb":
+                self.wandb_logger.save_checkpoint(
+                    self.file_name,
+                    ckpt_name,
+                    update_best_ckpt,
+                    metadata={
+                        "epoch": self.epoch + 1,
+                        "optimizer": self.optimizer.state_dict(),
+                        "best_ap": self.best_ap,
+                        "curr_ap": ap
+                    }
+                )
diff --git a/multimodal/YOLOX/yolox/data/__init__.py b/multimodal/YOLOX/yolox/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeaf4f930ab8b9890ca43ba031f5b035be623ccd
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/__init__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .data_augment import TrainTransform, ValTransform
+from .data_prefetcher import DataPrefetcher
+from .dataloading import DataLoader, get_yolox_datadir, worker_init_reset_seed
+from .datasets import *
+from .samplers import InfiniteSampler, YoloBatchSampler
diff --git a/multimodal/YOLOX/yolox/data/data_augment.py b/multimodal/YOLOX/yolox/data/data_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..21cd7b56d800a38d3782bf5072c03f9b2f9bb809
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/data_augment.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+"""
+Data augmentation functionality. Passed as callable transformations to
+Dataset classes.
+
+The data augmentation procedures were interpreted from @weiliu89's SSD paper
+http://arxiv.org/abs/1512.02325
+"""
+
+import math
+import random
+
+import cv2
+import numpy as np
+
+from yolox.utils import xyxy2cxcywh
+
+
+def augment_hsv(img, hgain=5, sgain=30, vgain=30):
+    hsv_augs = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain]  # random gains
+    hsv_augs *= np.random.randint(0, 2, 3)  # random selection of h, s, v
+    hsv_augs = hsv_augs.astype(np.int16)
+    img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.int16)
+
+    img_hsv[..., 0] = (img_hsv[..., 0] + hsv_augs[0]) % 180
+    img_hsv[..., 1] = np.clip(img_hsv[..., 1] + hsv_augs[1], 0, 255)
+    img_hsv[..., 2] = np.clip(img_hsv[..., 2] + hsv_augs[2], 0, 255)
+
+    cv2.cvtColor(img_hsv.astype(img.dtype), cv2.COLOR_HSV2BGR, dst=img)  # no return needed
+
+
+def get_aug_params(value, center=0):
+    if isinstance(value, float):
+        return random.uniform(center - value, center + value)
+    elif len(value) == 2:
+        return random.uniform(value[0], value[1])
+    else:
+        raise ValueError(
+            "Affine params should be either a sequence containing two values\
+             or single float values. Got {}".format(value)
+        )
+
+
+def get_affine_matrix(
+    target_size,
+    degrees=10,
+    translate=0.1,
+    scales=0.1,
+    shear=10,
+):
+    twidth, theight = target_size
+
+    # Rotation and Scale
+    angle = get_aug_params(degrees)
+    scale = get_aug_params(scales, center=1.0)
+
+    if scale <= 0.0:
+        raise ValueError("Argument scale should be positive")
+
+    R = cv2.getRotationMatrix2D(angle=angle, center=(0, 0), scale=scale)
+
+    M = np.ones([2, 3])
+    # Shear
+    shear_x = math.tan(get_aug_params(shear) * math.pi / 180)
+    shear_y = math.tan(get_aug_params(shear) * math.pi / 180)
+
+    M[0] = R[0] + shear_y * R[1]
+    M[1] = R[1] + shear_x * R[0]
+
+    # Translation
+    translation_x = get_aug_params(translate) * twidth  # x translation (pixels)
+    translation_y = get_aug_params(translate) * theight  # y translation (pixels)
+
+    M[0, 2] = translation_x
+    M[1, 2] = translation_y
+
+    return M, scale
+
+
+def apply_affine_to_bboxes(targets, target_size, M, scale):
+    num_gts = len(targets)
+
+    # warp corner points
+    twidth, theight = target_size
+    corner_points = np.ones((4 * num_gts, 3))
+    corner_points[:, :2] = targets[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
+        4 * num_gts, 2
+    )  # x1y1, x2y2, x1y2, x2y1
+    corner_points = corner_points @ M.T  # apply affine transform
+    corner_points = corner_points.reshape(num_gts, 8)
+
+    # create new boxes
+    corner_xs = corner_points[:, 0::2]
+    corner_ys = corner_points[:, 1::2]
+    new_bboxes = (
+        np.concatenate(
+            (corner_xs.min(1), corner_ys.min(1), corner_xs.max(1), corner_ys.max(1))
+        )
+        .reshape(4, num_gts)
+        .T
+    )
+
+    # clip boxes
+    new_bboxes[:, 0::2] = new_bboxes[:, 0::2].clip(0, twidth)
+    new_bboxes[:, 1::2] = new_bboxes[:, 1::2].clip(0, theight)
+
+    targets[:, :4] = new_bboxes
+
+    return targets
+
+
+def random_affine(
+    img,
+    targets=(),
+    target_size=(640, 640),
+    degrees=10,
+    translate=0.1,
+    scales=0.1,
+    shear=10,
+):
+    M, scale = get_affine_matrix(target_size, degrees, translate, scales, shear)
+
+    img = cv2.warpAffine(img, M, dsize=target_size, borderValue=(114, 114, 114))
+
+    # Transform label coordinates
+    if len(targets) > 0:
+        targets = apply_affine_to_bboxes(targets, target_size, M, scale)
+
+    return img, targets
+
+
+def _mirror(image, boxes, prob=0.5):
+    _, width, _ = image.shape
+    if random.random() < prob:
+        image = image[:, ::-1]
+        boxes[:, 0::2] = width - boxes[:, 2::-2]
+    return image, boxes
+
+
+def preproc(img, input_size, swap=(2, 0, 1)):
+    if len(img.shape) == 3:
+        padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
+    else:
+        padded_img = np.ones(input_size, dtype=np.uint8) * 114
+
+    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
+    resized_img = cv2.resize(
+        img,
+        (int(img.shape[1] * r), int(img.shape[0] * r)),
+        interpolation=cv2.INTER_LINEAR,
+    ).astype(np.uint8)
+    padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
+
+    padded_img = padded_img.transpose(swap)
+    padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+    return padded_img, r
+
+
+class TrainTransform:
+    def __init__(self, max_labels=50, flip_prob=0.5, hsv_prob=1.0):
+        self.max_labels = max_labels
+        self.flip_prob = flip_prob
+        self.hsv_prob = hsv_prob
+
+    def __call__(self, image, targets, input_dim):
+        boxes = targets[:, :4].copy()
+        labels = targets[:, 4].copy()
+        if len(boxes) == 0:
+            targets = np.zeros((self.max_labels, 5), dtype=np.float32)
+            image, r_o = preproc(image, input_dim)
+            return image, targets
+
+        image_o = image.copy()
+        targets_o = targets.copy()
+        height_o, width_o, _ = image_o.shape
+        boxes_o = targets_o[:, :4]
+        labels_o = targets_o[:, 4]
+        # bbox_o: [xyxy] to [c_x,c_y,w,h]
+        boxes_o = xyxy2cxcywh(boxes_o)
+
+        if random.random() < self.hsv_prob:
+            augment_hsv(image)
+        image_t, boxes = _mirror(image, boxes, self.flip_prob)
+        height, width, _ = image_t.shape
+        image_t, r_ = preproc(image_t, input_dim)
+        # boxes [xyxy] 2 [cx,cy,w,h]
+        boxes = xyxy2cxcywh(boxes)
+        boxes *= r_
+
+        mask_b = np.minimum(boxes[:, 2], boxes[:, 3]) > 1
+        boxes_t = boxes[mask_b]
+        labels_t = labels[mask_b]
+
+        if len(boxes_t) == 0:
+            image_t, r_o = preproc(image_o, input_dim)
+            boxes_o *= r_o
+            boxes_t = boxes_o
+            labels_t = labels_o
+
+        labels_t = np.expand_dims(labels_t, 1)
+
+        targets_t = np.hstack((labels_t, boxes_t))
+        padded_labels = np.zeros((self.max_labels, 5))
+        padded_labels[range(len(targets_t))[: self.max_labels]] = targets_t[
+            : self.max_labels
+        ]
+        padded_labels = np.ascontiguousarray(padded_labels, dtype=np.float32)
+        return image_t, padded_labels
+
+
+class ValTransform:
+    """
+    Defines the transformations that should be applied to test PIL image
+    for input into the network
+
+    dimension -> tensorize -> color adj
+
+    Arguments:
+        resize (int): input dimension to SSD
+        rgb_means ((int,int,int)): average RGB of the dataset
+            (104,117,123)
+        swap ((int,int,int)): final order of channels
+
+    Returns:
+        transform (transform) : callable transform to be applied to test/val
+        data
+    """
+
+    def __init__(self, swap=(2, 0, 1), legacy=False):
+        self.swap = swap
+        self.legacy = legacy
+
+    # assume input is cv2 img for now
+    def __call__(self, img, res, input_size):
+        img, _ = preproc(img, input_size, self.swap)
+        if self.legacy:
+            img = img[::-1, :, :].copy()
+            img /= 255.0
+            img -= np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
+            img /= np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
+        return img, np.zeros((1, 5))
diff --git a/multimodal/YOLOX/yolox/data/data_prefetcher.py b/multimodal/YOLOX/yolox/data/data_prefetcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..a118cf4e4ef968c9cf89a72457ede8c63bdf2cce
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/data_prefetcher.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import torch
+
+
+class DataPrefetcher:
+    """
+    DataPrefetcher is inspired by code of following file:
+    https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
+    It could speedup your pytorch dataloader. For more information, please check
+    https://github.com/NVIDIA/apex/issues/304#issuecomment-493562789.
+    """
+
+    def __init__(self, loader):
+        self.loader = iter(loader)
+        self.stream = torch.cuda.Stream()
+        self.input_cuda = self._input_cuda_for_image
+        self.record_stream = DataPrefetcher._record_stream_for_image
+        self.preload()
+
+    def preload(self):
+        try:
+            self.next_input, self.next_target, _, _ = next(self.loader)
+        except StopIteration:
+            self.next_input = None
+            self.next_target = None
+            return
+
+        with torch.cuda.stream(self.stream):
+            self.input_cuda()
+            self.next_target = self.next_target.cuda(non_blocking=True)
+
+    def next(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        input = self.next_input
+        target = self.next_target
+        if input is not None:
+            self.record_stream(input)
+        if target is not None:
+            target.record_stream(torch.cuda.current_stream())
+        self.preload()
+        return input, target
+
+    def _input_cuda_for_image(self):
+        self.next_input = self.next_input.cuda(non_blocking=True)
+
+    @staticmethod
+    def _record_stream_for_image(input):
+        input.record_stream(torch.cuda.current_stream())
diff --git a/multimodal/YOLOX/yolox/data/dataloading.py b/multimodal/YOLOX/yolox/data/dataloading.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fecf3f06abe908ea5f0d84fba85d2e230257512
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/dataloading.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+import random
+import uuid
+
+import numpy as np
+
+import torch
+from torch.utils.data.dataloader import DataLoader as torchDataLoader
+from torch.utils.data.dataloader import default_collate
+
+from .samplers import YoloBatchSampler
+
+
+def get_yolox_datadir():
+    """
+    get dataset dir of YOLOX. If environment variable named `YOLOX_DATADIR` is set,
+    this function will return value of the environment variable. Otherwise, use data
+    """
+    yolox_datadir = os.getenv("YOLOX_DATADIR", None)
+    if yolox_datadir is None:
+        import yolox
+
+        yolox_path = os.path.dirname(os.path.dirname(yolox.__file__))
+        yolox_datadir = os.path.join(yolox_path, "datasets")
+    return yolox_datadir
+
+
+class DataLoader(torchDataLoader):
+    """
+    Lightnet dataloader that enables on the fly resizing of the images.
+    See :class:`torch.utils.data.DataLoader` for more information on the arguments.
+    Check more on the following website:
+    https://gitlab.com/EAVISE/lightnet/-/blob/master/lightnet/data/_dataloading.py
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__initialized = False
+        shuffle = False
+        batch_sampler = None
+        if len(args) > 5:
+            shuffle = args[2]
+            sampler = args[3]
+            batch_sampler = args[4]
+        elif len(args) > 4:
+            shuffle = args[2]
+            sampler = args[3]
+            if "batch_sampler" in kwargs:
+                batch_sampler = kwargs["batch_sampler"]
+        elif len(args) > 3:
+            shuffle = args[2]
+            if "sampler" in kwargs:
+                sampler = kwargs["sampler"]
+            if "batch_sampler" in kwargs:
+                batch_sampler = kwargs["batch_sampler"]
+        else:
+            if "shuffle" in kwargs:
+                shuffle = kwargs["shuffle"]
+            if "sampler" in kwargs:
+                sampler = kwargs["sampler"]
+            if "batch_sampler" in kwargs:
+                batch_sampler = kwargs["batch_sampler"]
+
+        # Use custom BatchSampler
+        if batch_sampler is None:
+            if sampler is None:
+                if shuffle:
+                    sampler = torch.utils.data.sampler.RandomSampler(self.dataset)
+                    # sampler = torch.utils.data.DistributedSampler(self.dataset)
+                else:
+                    sampler = torch.utils.data.sampler.SequentialSampler(self.dataset)
+            batch_sampler = YoloBatchSampler(
+                sampler,
+                self.batch_size,
+                self.drop_last,
+                input_dimension=self.dataset.input_dim,
+            )
+            # batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations =
+
+        self.batch_sampler = batch_sampler
+
+        self.__initialized = True
+
+    def close_mosaic(self):
+        self.batch_sampler.mosaic = False
+
+
+def list_collate(batch):
+    """
+    Function that collates lists or tuples together into one list (of lists/tuples).
+    Use this as the collate function in a Dataloader, if you want to have a list of
+    items as an output, as opposed to tensors (eg. Brambox.boxes).
+    """
+    items = list(zip(*batch))
+
+    for i in range(len(items)):
+        if isinstance(items[i][0], (list, tuple)):
+            items[i] = list(items[i])
+        else:
+            items[i] = default_collate(items[i])
+
+    return items
+
+
+def worker_init_reset_seed(worker_id):
+    seed = uuid.uuid4().int % 2**32
+    random.seed(seed)
+    torch.set_rng_state(torch.manual_seed(seed).get_state())
+    np.random.seed(seed)
diff --git a/multimodal/YOLOX/yolox/data/datasets/__init__.py b/multimodal/YOLOX/yolox/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b6fd8ec4cecffe94d80084b57f3b966e4f01def
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/datasets/__init__.py
@@ -0,0 +1,9 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .coco import COCODataset
+from .coco_classes import COCO_CLASSES
+from .datasets_wrapper import CacheDataset, ConcatDataset, Dataset, MixConcatDataset
+from .mosaicdetection import MosaicDetection
+from .voc import VOCDetection
diff --git a/multimodal/YOLOX/yolox/data/datasets/coco.py b/multimodal/YOLOX/yolox/data/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d19047a2bdef1c2a1af544d484cb2eee3af8aaa
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/datasets/coco.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+import copy
+import os
+
+import cv2
+import numpy as np
+from pycocotools.coco import COCO
+
+from ..dataloading import get_yolox_datadir
+from .datasets_wrapper import CacheDataset, cache_read_img
+
+
+def remove_useless_info(coco):
+    """
+    Remove useless info in coco dataset. COCO object is modified inplace.
+    This function is mainly used for saving memory (save about 30% mem).
+    """
+    if isinstance(coco, COCO):
+        dataset = coco.dataset
+        dataset.pop("info", None)
+        dataset.pop("licenses", None)
+        for img in dataset["images"]:
+            img.pop("license", None)
+            img.pop("coco_url", None)
+            img.pop("date_captured", None)
+            img.pop("flickr_url", None)
+        if "annotations" in coco.dataset:
+            for anno in coco.dataset["annotations"]:
+                anno.pop("segmentation", None)
+
+
+class COCODataset(CacheDataset):
+    """
+    COCO dataset class.
+    """
+
+    def __init__(
+        self,
+        data_dir=None,
+        json_file="instances_train2017.json",
+        name="train2017",
+        img_size=(416, 416),
+        preproc=None,
+        cache=False,
+        cache_type="ram",
+    ):
+        """
+        COCO dataset initialization. Annotation data are read into memory by COCO API.
+        Args:
+            data_dir (str): dataset root directory
+            json_file (str): COCO json file name
+            name (str): COCO data name (e.g. 'train2017' or 'val2017')
+            img_size (int): target image size after pre-processing
+            preproc: data augmentation strategy
+        """
+        if data_dir is None:
+            data_dir = os.path.join(get_yolox_datadir(), "COCO")
+        self.data_dir = data_dir
+        self.json_file = json_file
+
+        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
+        remove_useless_info(self.coco)
+        self.ids = self.coco.getImgIds()
+        self.num_imgs = len(self.ids)
+        self.class_ids = sorted(self.coco.getCatIds())
+        self.cats = self.coco.loadCats(self.coco.getCatIds())
+        self._classes = tuple([c["name"] for c in self.cats])
+        self.name = name
+        self.img_size = img_size
+        self.preproc = preproc
+        self.annotations = self._load_coco_annotations()
+
+        path_filename = [os.path.join(name, anno[3]) for anno in self.annotations]
+        super().__init__(
+            input_dimension=img_size,
+            num_imgs=self.num_imgs,
+            data_dir=data_dir,
+            cache_dir_name=f"cache_{name}",
+            path_filename=path_filename,
+            cache=cache,
+            cache_type=cache_type
+        )
+
+    def __len__(self):
+        return self.num_imgs
+
+    def _load_coco_annotations(self):
+        return [self.load_anno_from_ids(_ids) for _ids in self.ids]
+
+    def load_anno_from_ids(self, id_):
+        im_ann = self.coco.loadImgs(id_)[0]
+        width = im_ann["width"]
+        height = im_ann["height"]
+        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
+        annotations = self.coco.loadAnns(anno_ids)
+        objs = []
+        for obj in annotations:
+            x1 = np.max((0, obj["bbox"][0]))
+            y1 = np.max((0, obj["bbox"][1]))
+            x2 = np.min((width, x1 + np.max((0, obj["bbox"][2]))))
+            y2 = np.min((height, y1 + np.max((0, obj["bbox"][3]))))
+            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
+                obj["clean_bbox"] = [x1, y1, x2, y2]
+                objs.append(obj)
+
+        num_objs = len(objs)
+
+        res = np.zeros((num_objs, 5))
+        for ix, obj in enumerate(objs):
+            cls = self.class_ids.index(obj["category_id"])
+            res[ix, 0:4] = obj["clean_bbox"]
+            res[ix, 4] = cls
+
+        r = min(self.img_size[0] / height, self.img_size[1] / width)
+        res[:, :4] *= r
+
+        img_info = (height, width)
+        resized_info = (int(height * r), int(width * r))
+
+        file_name = (
+            im_ann["file_name"]
+            if "file_name" in im_ann
+            else "{:012}".format(id_) + ".jpg"
+        )
+
+        return (res, img_info, resized_info, file_name)
+
+    def load_anno(self, index):
+        return self.annotations[index][0]
+
+    def load_resized_img(self, index):
+        img = self.load_image(index)
+        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * r), int(img.shape[0] * r)),
+            interpolation=cv2.INTER_LINEAR,
+        ).astype(np.uint8)
+        return resized_img
+
+    def load_image(self, index):
+        file_name = self.annotations[index][3]
+
+        img_file = os.path.join(self.data_dir, self.name, file_name)
+
+        img = cv2.imread(img_file)
+        assert img is not None, f"file named {img_file} not found"
+
+        return img
+
+    @cache_read_img(use_cache=True)
+    def read_img(self, index):
+        return self.load_resized_img(index)
+
+    def pull_item(self, index):
+        id_ = self.ids[index]
+        label, origin_image_size, _, _ = self.annotations[index]
+        img = self.read_img(index)
+
+        return img, copy.deepcopy(label), origin_image_size, np.array([id_])
+
+    @CacheDataset.mosaic_getitem
+    def __getitem__(self, index):
+        """
+        One image / label pair for the given index is picked up and pre-processed.
+
+        Args:
+            index (int): data index
+
+        Returns:
+            img (numpy.ndarray): pre-processed image
+            padded_labels (torch.Tensor): pre-processed label data.
+                The shape is :math:`[max_labels, 5]`.
+                each label consists of [class, xc, yc, w, h]:
+                    class (float): class index.
+                    xc, yc (float) : center of bbox whose values range from 0 to 1.
+                    w, h (float) : size of bbox whose values range from 0 to 1.
+            info_img : tuple of h, w.
+                h, w (int): original shape of the image
+            img_id (int): same as the input index. Used for evaluation.
+        """
+        img, target, img_info, img_id = self.pull_item(index)
+
+        if self.preproc is not None:
+            img, target = self.preproc(img, target, self.input_dim)
+        return img, target, img_info, img_id
diff --git a/multimodal/YOLOX/yolox/data/datasets/coco_classes.py b/multimodal/YOLOX/yolox/data/datasets/coco_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..17f5cbe6e86ed4fc8378760da71f8349a6406ff1
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/datasets/coco_classes.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+COCO_CLASSES = (
+    "person",
+    "bicycle",
+    "car",
+    "motorcycle",
+    "airplane",
+    "bus",
+    "train",
+    "truck",
+    "boat",
+    "traffic light",
+    "fire hydrant",
+    "stop sign",
+    "parking meter",
+    "bench",
+    "bird",
+    "cat",
+    "dog",
+    "horse",
+    "sheep",
+    "cow",
+    "elephant",
+    "bear",
+    "zebra",
+    "giraffe",
+    "backpack",
+    "umbrella",
+    "handbag",
+    "tie",
+    "suitcase",
+    "frisbee",
+    "skis",
+    "snowboard",
+    "sports ball",
+    "kite",
+    "baseball bat",
+    "baseball glove",
+    "skateboard",
+    "surfboard",
+    "tennis racket",
+    "bottle",
+    "wine glass",
+    "cup",
+    "fork",
+    "knife",
+    "spoon",
+    "bowl",
+    "banana",
+    "apple",
+    "sandwich",
+    "orange",
+    "broccoli",
+    "carrot",
+    "hot dog",
+    "pizza",
+    "donut",
+    "cake",
+    "chair",
+    "couch",
+    "potted plant",
+    "bed",
+    "dining table",
+    "toilet",
+    "tv",
+    "laptop",
+    "mouse",
+    "remote",
+    "keyboard",
+    "cell phone",
+    "microwave",
+    "oven",
+    "toaster",
+    "sink",
+    "refrigerator",
+    "book",
+    "clock",
+    "vase",
+    "scissors",
+    "teddy bear",
+    "hair drier",
+    "toothbrush",
+)
diff --git a/multimodal/YOLOX/yolox/data/datasets/datasets_wrapper.py b/multimodal/YOLOX/yolox/data/datasets/datasets_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c45fe380f5b7ac1c40452ff3903da651fe324225
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/datasets/datasets_wrapper.py
@@ -0,0 +1,300 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import bisect
+import copy
+import os
+import random
+from abc import ABCMeta, abstractmethod
+from functools import partial, wraps
+from multiprocessing.pool import ThreadPool
+import psutil
+from loguru import logger
+from tqdm import tqdm
+
+import numpy as np
+
+from torch.utils.data.dataset import ConcatDataset as torchConcatDataset
+from torch.utils.data.dataset import Dataset as torchDataset
+
+
+class ConcatDataset(torchConcatDataset):
+    def __init__(self, datasets):
+        super(ConcatDataset, self).__init__(datasets)
+        if hasattr(self.datasets[0], "input_dim"):
+            self._input_dim = self.datasets[0].input_dim
+            self.input_dim = self.datasets[0].input_dim
+
+    def pull_item(self, idx):
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        return self.datasets[dataset_idx].pull_item(sample_idx)
+
+
+class MixConcatDataset(torchConcatDataset):
+    def __init__(self, datasets):
+        super(MixConcatDataset, self).__init__(datasets)
+        if hasattr(self.datasets[0], "input_dim"):
+            self._input_dim = self.datasets[0].input_dim
+            self.input_dim = self.datasets[0].input_dim
+
+    def __getitem__(self, index):
+
+        if not isinstance(index, int):
+            idx = index[1]
+        if idx < 0:
+            if -idx > len(self):
+                raise ValueError(
+                    "absolute value of index should not exceed dataset length"
+                )
+            idx = len(self) + idx
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        if dataset_idx == 0:
+            sample_idx = idx
+        else:
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+        if not isinstance(index, int):
+            index = (index[0], sample_idx, index[2])
+
+        return self.datasets[dataset_idx][index]
+
+
+class Dataset(torchDataset):
+    """ This class is a subclass of the base :class:`torch.utils.data.Dataset`,
+    that enables on the fly resizing of the ``input_dim``.
+
+    Args:
+        input_dimension (tuple): (width,height) tuple with default dimensions of the network
+    """
+
+    def __init__(self, input_dimension, mosaic=True):
+        super().__init__()
+        self.__input_dim = input_dimension[:2]
+        self.enable_mosaic = mosaic
+
+    @property
+    def input_dim(self):
+        """
+        Dimension that can be used by transforms to set the correct image size, etc.
+        This allows transforms to have a single source of truth
+        for the input dimension of the network.
+
+        Return:
+            list: Tuple containing the current width,height
+        """
+        if hasattr(self, "_input_dim"):
+            return self._input_dim
+        return self.__input_dim
+
+    @staticmethod
+    def mosaic_getitem(getitem_fn):
+        """
+        Decorator method that needs to be used around the ``__getitem__`` method. |br|
+        This decorator enables the closing mosaic
+
+        Example:
+            >>> class CustomSet(ln.data.Dataset):
+            ...     def __len__(self):
+            ...         return 10
+            ...     @ln.data.Dataset.mosaic_getitem
+            ...     def __getitem__(self, index):
+            ...         return self.enable_mosaic
+        """
+
+        @wraps(getitem_fn)
+        def wrapper(self, index):
+            if not isinstance(index, int):
+                self.enable_mosaic = index[0]
+                index = index[1]
+
+            ret_val = getitem_fn(self, index)
+
+            return ret_val
+
+        return wrapper
+
+
+class CacheDataset(Dataset, metaclass=ABCMeta):
+    """ This class is a subclass of the base :class:`yolox.data.datasets.Dataset`,
+    that enables cache images to ram or disk.
+
+    Args:
+        input_dimension (tuple): (width,height) tuple with default dimensions of the network
+        num_imgs (int): datset size
+        data_dir (str): the root directory of the dataset, e.g. `/path/to/COCO`.
+        cache_dir_name (str): the name of the directory to cache to disk,
+            e.g. `"custom_cache"`. The files cached to disk will be saved
+            under `/path/to/COCO/custom_cache`.
+        path_filename (str): a list of paths to the data relative to the `data_dir`,
+            e.g. if you have data `/path/to/COCO/train/1.jpg`, `/path/to/COCO/train/2.jpg`,
+            then `path_filename = ['train/1.jpg', ' train/2.jpg']`.
+        cache (bool): whether to cache the images to ram or disk.
+        cache_type (str): the type of cache,
+            "ram" : Caching imgs to ram for fast training.
+            "disk": Caching imgs to disk for fast training.
+    """
+
+    def __init__(
+        self,
+        input_dimension,
+        num_imgs=None,
+        data_dir=None,
+        cache_dir_name=None,
+        path_filename=None,
+        cache=False,
+        cache_type="ram",
+    ):
+        super().__init__(input_dimension)
+        self.cache = cache
+        self.cache_type = cache_type
+
+        if self.cache and self.cache_type == "disk":
+            self.cache_dir = os.path.join(data_dir, cache_dir_name)
+            self.path_filename = path_filename
+
+        if self.cache and self.cache_type == "ram":
+            self.imgs = None
+
+        if self.cache:
+            self.cache_images(
+                num_imgs=num_imgs,
+                data_dir=data_dir,
+                cache_dir_name=cache_dir_name,
+                path_filename=path_filename,
+            )
+
+    def __del__(self):
+        if self.cache and self.cache_type == "ram":
+            del self.imgs
+
+    @abstractmethod
+    def read_img(self, index):
+        """
+        Given index, return the corresponding image
+
+        Args:
+            index (int): image index
+        """
+        raise NotImplementedError
+
+    def cache_images(
+        self,
+        num_imgs=None,
+        data_dir=None,
+        cache_dir_name=None,
+        path_filename=None,
+    ):
+        assert num_imgs is not None, "num_imgs must be specified as the size of the dataset"
+        if self.cache_type == "disk":
+            assert (data_dir and cache_dir_name and path_filename) is not None, \
+                "data_dir, cache_name and path_filename must be specified if cache_type is disk"
+            self.path_filename = path_filename
+
+        mem = psutil.virtual_memory()
+        mem_required = self.cal_cache_occupy(num_imgs)
+        gb = 1 << 30
+
+        if self.cache_type == "ram":
+            if mem_required > mem.available:
+                self.cache = False
+            else:
+                logger.info(
+                    f"{mem_required / gb:.1f}GB RAM required, "
+                    f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB RAM available, "
+                    f"Since the first thing we do is cache, "
+                    f"there is no guarantee that the remaining memory space is sufficient"
+                )
+
+        if self.cache and self.imgs is None:
+            if self.cache_type == 'ram':
+                self.imgs = [None] * num_imgs
+                logger.info("You are using cached images in RAM to accelerate training!")
+            else:   # 'disk'
+                if not os.path.exists(self.cache_dir):
+                    os.mkdir(self.cache_dir)
+                    logger.warning(
+                        f"\n*******************************************************************\n"
+                        f"You are using cached images in DISK to accelerate training.\n"
+                        f"This requires large DISK space.\n"
+                        f"Make sure you have {mem_required / gb:.1f} "
+                        f"available DISK space for training your dataset.\n"
+                        f"*******************************************************************\\n"
+                    )
+                else:
+                    logger.info(f"Found disk cache at {self.cache_dir}")
+                    return
+
+            logger.info(
+                "Caching images...\n"
+                "This might take some time for your dataset"
+            )
+
+            num_threads = min(8, max(1, os.cpu_count() - 1))
+            b = 0
+            load_imgs = ThreadPool(num_threads).imap(
+                partial(self.read_img, use_cache=False),
+                range(num_imgs)
+            )
+            pbar = tqdm(enumerate(load_imgs), total=num_imgs)
+            for i, x in pbar:   # x = self.read_img(self, i, use_cache=False)
+                if self.cache_type == 'ram':
+                    self.imgs[i] = x
+                else:   # 'disk'
+                    cache_filename = f'{self.path_filename[i].split(".")[0]}.npy'
+                    cache_path_filename = os.path.join(self.cache_dir, cache_filename)
+                    os.makedirs(os.path.dirname(cache_path_filename), exist_ok=True)
+                    np.save(cache_path_filename, x)
+                b += x.nbytes
+                pbar.desc = \
+                    f'Caching images ({b / gb:.1f}/{mem_required / gb:.1f}GB {self.cache_type})'
+            pbar.close()
+
+    def cal_cache_occupy(self, num_imgs):
+        cache_bytes = 0
+        num_samples = min(num_imgs, 32)
+        for _ in range(num_samples):
+            img = self.read_img(index=random.randint(0, num_imgs - 1), use_cache=False)
+            cache_bytes += img.nbytes
+        mem_required = cache_bytes * num_imgs / num_samples
+        return mem_required
+
+
+def cache_read_img(use_cache=True):
+    def decorator(read_img_fn):
+        """
+        Decorate the read_img function to cache the image
+
+        Args:
+            read_img_fn: read_img function
+            use_cache (bool, optional): For the decorated read_img function,
+                whether to read the image from cache.
+                Defaults to True.
+        """
+        @wraps(read_img_fn)
+        def wrapper(self, index, use_cache=use_cache):
+            cache = self.cache and use_cache
+            if cache:
+                if self.cache_type == "ram":
+                    img = self.imgs[index]
+                    img = copy.deepcopy(img)
+                elif self.cache_type == "disk":
+                    img = np.load(
+                        os.path.join(
+                            self.cache_dir, f"{self.path_filename[index].split('.')[0]}.npy"))
+                else:
+                    raise ValueError(f"Unknown cache type: {self.cache_type}")
+            else:
+                img = read_img_fn(self, index)
+            return img
+        return wrapper
+    return decorator
diff --git a/multimodal/YOLOX/yolox/data/datasets/mosaicdetection.py b/multimodal/YOLOX/yolox/data/datasets/mosaicdetection.py
new file mode 100644
index 0000000000000000000000000000000000000000..708babed55086113e9ec69f57e9408b6a28b9422
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/datasets/mosaicdetection.py
@@ -0,0 +1,234 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import random
+
+import cv2
+import numpy as np
+
+from yolox.utils import adjust_box_anns, get_local_rank
+
+from ..data_augment import random_affine
+from .datasets_wrapper import Dataset
+
+
+def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
+    # TODO update doc
+    # index0 to top left part of image
+    if mosaic_index == 0:
+        x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
+        small_coord = w - (x2 - x1), h - (y2 - y1), w, h
+    # index1 to top right part of image
+    elif mosaic_index == 1:
+        x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
+        small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
+    # index2 to bottom left part of image
+    elif mosaic_index == 2:
+        x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
+        small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
+    # index2 to bottom right part of image
+    elif mosaic_index == 3:
+        x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h)  # noqa
+        small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
+    return (x1, y1, x2, y2), small_coord
+
+
+class MosaicDetection(Dataset):
+    """Detection dataset wrapper that performs mixup for normal dataset."""
+
+    def __init__(
+        self, dataset, img_size, mosaic=True, preproc=None,
+        degrees=10.0, translate=0.1, mosaic_scale=(0.5, 1.5),
+        mixup_scale=(0.5, 1.5), shear=2.0, enable_mixup=True,
+        mosaic_prob=1.0, mixup_prob=1.0, *args
+    ):
+        """
+
+        Args:
+            dataset(Dataset) : Pytorch dataset object.
+            img_size (tuple):
+            mosaic (bool): enable mosaic augmentation or not.
+            preproc (func):
+            degrees (float):
+            translate (float):
+            mosaic_scale (tuple):
+            mixup_scale (tuple):
+            shear (float):
+            enable_mixup (bool):
+            *args(tuple) : Additional arguments for mixup random sampler.
+        """
+        super().__init__(img_size, mosaic=mosaic)
+        self._dataset = dataset
+        self.preproc = preproc
+        self.degrees = degrees
+        self.translate = translate
+        self.scale = mosaic_scale
+        self.shear = shear
+        self.mixup_scale = mixup_scale
+        self.enable_mosaic = mosaic
+        self.enable_mixup = enable_mixup
+        self.mosaic_prob = mosaic_prob
+        self.mixup_prob = mixup_prob
+        self.local_rank = get_local_rank()
+
+    def __len__(self):
+        return len(self._dataset)
+
+    @Dataset.mosaic_getitem
+    def __getitem__(self, idx):
+        if self.enable_mosaic and random.random() < self.mosaic_prob:
+            mosaic_labels = []
+            input_dim = self._dataset.input_dim
+            input_h, input_w = input_dim[0], input_dim[1]
+
+            # yc, xc = s, s  # mosaic center x, y
+            yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
+            xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))
+
+            # 3 additional image indices
+            indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]
+
+            for i_mosaic, index in enumerate(indices):
+                img, _labels, _, img_id = self._dataset.pull_item(index)
+                h0, w0 = img.shape[:2]  # orig hw
+                scale = min(1. * input_h / h0, 1. * input_w / w0)
+                img = cv2.resize(
+                    img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR
+                )
+                # generate output mosaic image
+                (h, w, c) = img.shape[:3]
+                if i_mosaic == 0:
+                    mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)
+
+                # suffix l means large image, while s means small image in mosaic aug.
+                (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(
+                    mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w
+                )
+
+                mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
+                padw, padh = l_x1 - s_x1, l_y1 - s_y1
+
+                labels = _labels.copy()
+                # Normalized xywh to pixel xyxy format
+                if _labels.size > 0:
+                    labels[:, 0] = scale * _labels[:, 0] + padw
+                    labels[:, 1] = scale * _labels[:, 1] + padh
+                    labels[:, 2] = scale * _labels[:, 2] + padw
+                    labels[:, 3] = scale * _labels[:, 3] + padh
+                mosaic_labels.append(labels)
+
+            if len(mosaic_labels):
+                mosaic_labels = np.concatenate(mosaic_labels, 0)
+                np.clip(mosaic_labels[:, 0], 0, 2 * input_w, out=mosaic_labels[:, 0])
+                np.clip(mosaic_labels[:, 1], 0, 2 * input_h, out=mosaic_labels[:, 1])
+                np.clip(mosaic_labels[:, 2], 0, 2 * input_w, out=mosaic_labels[:, 2])
+                np.clip(mosaic_labels[:, 3], 0, 2 * input_h, out=mosaic_labels[:, 3])
+
+            mosaic_img, mosaic_labels = random_affine(
+                mosaic_img,
+                mosaic_labels,
+                target_size=(input_w, input_h),
+                degrees=self.degrees,
+                translate=self.translate,
+                scales=self.scale,
+                shear=self.shear,
+            )
+
+            # -----------------------------------------------------------------
+            # CopyPaste: https://arxiv.org/abs/2012.07177
+            # -----------------------------------------------------------------
+            if (
+                self.enable_mixup
+                and not len(mosaic_labels) == 0
+                and random.random() < self.mixup_prob
+            ):
+                mosaic_img, mosaic_labels = self.mixup(mosaic_img, mosaic_labels, self.input_dim)
+            mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim)
+            img_info = (mix_img.shape[1], mix_img.shape[0])
+
+            # -----------------------------------------------------------------
+            # img_info and img_id are not used for training.
+            # They are also hard to be specified on a mosaic image.
+            # -----------------------------------------------------------------
+            return mix_img, padded_labels, img_info, img_id
+
+        else:
+            self._dataset._input_dim = self.input_dim
+            img, label, img_info, img_id = self._dataset.pull_item(idx)
+            img, label = self.preproc(img, label, self.input_dim)
+            return img, label, img_info, img_id
+
+    def mixup(self, origin_img, origin_labels, input_dim):
+        jit_factor = random.uniform(*self.mixup_scale)
+        FLIP = random.uniform(0, 1) > 0.5
+        cp_labels = []
+        while len(cp_labels) == 0:
+            cp_index = random.randint(0, self.__len__() - 1)
+            cp_labels = self._dataset.load_anno(cp_index)
+        img, cp_labels, _, _ = self._dataset.pull_item(cp_index)
+
+        if len(img.shape) == 3:
+            cp_img = np.ones((input_dim[0], input_dim[1], 3), dtype=np.uint8) * 114
+        else:
+            cp_img = np.ones(input_dim, dtype=np.uint8) * 114
+
+        cp_scale_ratio = min(input_dim[0] / img.shape[0], input_dim[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * cp_scale_ratio), int(img.shape[0] * cp_scale_ratio)),
+            interpolation=cv2.INTER_LINEAR,
+        )
+
+        cp_img[
+            : int(img.shape[0] * cp_scale_ratio), : int(img.shape[1] * cp_scale_ratio)
+        ] = resized_img
+
+        cp_img = cv2.resize(
+            cp_img,
+            (int(cp_img.shape[1] * jit_factor), int(cp_img.shape[0] * jit_factor)),
+        )
+        cp_scale_ratio *= jit_factor
+
+        if FLIP:
+            cp_img = cp_img[:, ::-1, :]
+
+        origin_h, origin_w = cp_img.shape[:2]
+        target_h, target_w = origin_img.shape[:2]
+        padded_img = np.zeros(
+            (max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8
+        )
+        padded_img[:origin_h, :origin_w] = cp_img
+
+        x_offset, y_offset = 0, 0
+        if padded_img.shape[0] > target_h:
+            y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
+        if padded_img.shape[1] > target_w:
+            x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
+        padded_cropped_img = padded_img[
+            y_offset: y_offset + target_h, x_offset: x_offset + target_w
+        ]
+
+        cp_bboxes_origin_np = adjust_box_anns(
+            cp_labels[:, :4].copy(), cp_scale_ratio, 0, 0, origin_w, origin_h
+        )
+        if FLIP:
+            cp_bboxes_origin_np[:, 0::2] = (
+                origin_w - cp_bboxes_origin_np[:, 0::2][:, ::-1]
+            )
+        cp_bboxes_transformed_np = cp_bboxes_origin_np.copy()
+        cp_bboxes_transformed_np[:, 0::2] = np.clip(
+            cp_bboxes_transformed_np[:, 0::2] - x_offset, 0, target_w
+        )
+        cp_bboxes_transformed_np[:, 1::2] = np.clip(
+            cp_bboxes_transformed_np[:, 1::2] - y_offset, 0, target_h
+        )
+
+        cls_labels = cp_labels[:, 4:5].copy()
+        box_labels = cp_bboxes_transformed_np
+        labels = np.hstack((box_labels, cls_labels))
+        origin_labels = np.vstack((origin_labels, labels))
+        origin_img = origin_img.astype(np.float32)
+        origin_img = 0.5 * origin_img + 0.5 * padded_cropped_img.astype(np.float32)
+
+        return origin_img.astype(np.uint8), origin_labels
diff --git a/multimodal/YOLOX/yolox/data/datasets/voc.py b/multimodal/YOLOX/yolox/data/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdacd80191bc50b92185b73c97a68d792041feaa
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/datasets/voc.py
@@ -0,0 +1,331 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Code are based on
+# https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
+# Copyright (c) Francisco Massa.
+# Copyright (c) Ellis Brown, Max deGroot.
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+import os.path
+import pickle
+import xml.etree.ElementTree as ET
+
+import cv2
+import numpy as np
+
+from yolox.evaluators.voc_eval import voc_eval
+
+from .datasets_wrapper import CacheDataset, cache_read_img
+from .voc_classes import VOC_CLASSES
+
+
+class AnnotationTransform(object):
+
+    """Transforms a VOC annotation into a Tensor of bbox coords and label index
+    Initilized with a dictionary lookup of classnames to indexes
+
+    Arguments:
+        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
+            (default: alphabetic indexing of VOC's 20 classes)
+        keep_difficult (bool, optional): keep difficult instances or not
+            (default: False)
+        height (int): height
+        width (int): width
+    """
+
+    def __init__(self, class_to_ind=None, keep_difficult=True):
+        self.class_to_ind = class_to_ind or dict(
+            zip(VOC_CLASSES, range(len(VOC_CLASSES)))
+        )
+        self.keep_difficult = keep_difficult
+
+    def __call__(self, target):
+        """
+        Arguments:
+            target (annotation) : the target annotation to be made usable
+                will be an ET.Element
+        Returns:
+            a list containing lists of bounding boxes  [bbox coords, class name]
+        """
+        res = np.empty((0, 5))
+        for obj in target.iter("object"):
+            difficult = obj.find("difficult")
+            if difficult is not None:
+                difficult = int(difficult.text) == 1
+            else:
+                difficult = False
+            if not self.keep_difficult and difficult:
+                continue
+            name = obj.find("name").text.strip()
+            bbox = obj.find("bndbox")
+
+            pts = ["xmin", "ymin", "xmax", "ymax"]
+            bndbox = []
+            for i, pt in enumerate(pts):
+                cur_pt = int(float(bbox.find(pt).text)) - 1
+                # scale height or width
+                # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
+                bndbox.append(cur_pt)
+            label_idx = self.class_to_ind[name]
+            bndbox.append(label_idx)
+            res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]
+            # img_id = target.find('filename').text[:-4]
+
+        width = int(target.find("size").find("width").text)
+        height = int(target.find("size").find("height").text)
+        img_info = (height, width)
+
+        return res, img_info
+
+
+class VOCDetection(CacheDataset):
+
+    """
+    VOC Detection Dataset Object
+
+    input is image, target is annotation
+
+    Args:
+        root (string): filepath to VOCdevkit folder.
+        image_set (string): imageset to use (eg. 'train', 'val', 'test')
+        transform (callable, optional): transformation to perform on the
+            input image
+        target_transform (callable, optional): transformation to perform on the
+            target `annotation`
+            (eg: take in caption string, return tensor of word indices)
+        dataset_name (string, optional): which dataset to load
+            (default: 'VOC2007')
+    """
+
+    def __init__(
+        self,
+        data_dir,
+        image_sets=[("2007", "trainval"), ("2012", "trainval")],
+        img_size=(416, 416),
+        preproc=None,
+        target_transform=AnnotationTransform(),
+        dataset_name="VOC0712",
+        cache=False,
+        cache_type="ram",
+    ):
+        self.root = data_dir
+        self.image_set = image_sets
+        self.img_size = img_size
+        self.preproc = preproc
+        self.target_transform = target_transform
+        self.name = dataset_name
+        self._annopath = os.path.join("%s", "Annotations", "%s.xml")
+        self._imgpath = os.path.join("%s", "JPEGImages", "%s.jpg")
+        self._classes = VOC_CLASSES
+        self.cats = [
+            {"id": idx, "name": val} for idx, val in enumerate(VOC_CLASSES)
+        ]
+        self.class_ids = list(range(len(VOC_CLASSES)))
+        self.ids = list()
+        for (year, name) in image_sets:
+            self._year = year
+            rootpath = os.path.join(self.root, "VOC" + year)
+            for line in open(
+                os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
+            ):
+                self.ids.append((rootpath, line.strip()))
+        self.num_imgs = len(self.ids)
+
+        self.annotations = self._load_coco_annotations()
+
+        path_filename = [
+            (self._imgpath % self.ids[i]).split(self.root + "/")[1]
+            for i in range(self.num_imgs)
+        ]
+        super().__init__(
+            input_dimension=img_size,
+            num_imgs=self.num_imgs,
+            data_dir=self.root,
+            cache_dir_name=f"cache_{self.name}",
+            path_filename=path_filename,
+            cache=cache,
+            cache_type=cache_type
+        )
+
+    def __len__(self):
+        return self.num_imgs
+
+    def _load_coco_annotations(self):
+        return [self.load_anno_from_ids(_ids) for _ids in range(self.num_imgs)]
+
+    def load_anno_from_ids(self, index):
+        img_id = self.ids[index]
+        target = ET.parse(self._annopath % img_id).getroot()
+
+        assert self.target_transform is not None
+        res, img_info = self.target_transform(target)
+        height, width = img_info
+
+        r = min(self.img_size[0] / height, self.img_size[1] / width)
+        res[:, :4] *= r
+        resized_info = (int(height * r), int(width * r))
+
+        return (res, img_info, resized_info)
+
+    def load_anno(self, index):
+        return self.annotations[index][0]
+
+    def load_resized_img(self, index):
+        img = self.load_image(index)
+        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
+        resized_img = cv2.resize(
+            img,
+            (int(img.shape[1] * r), int(img.shape[0] * r)),
+            interpolation=cv2.INTER_LINEAR,
+        ).astype(np.uint8)
+
+        return resized_img
+
+    def load_image(self, index):
+        img_id = self.ids[index]
+        img = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
+        assert img is not None, f"file named {self._imgpath % img_id} not found"
+
+        return img
+
+    @cache_read_img(use_cache=True)
+    def read_img(self, index):
+        return self.load_resized_img(index)
+
+    def pull_item(self, index):
+        """Returns the original image and target at an index for mixup
+
+        Note: not using self.__getitem__(), as any transformations passed in
+        could mess up this functionality.
+
+        Argument:
+            index (int): index of img to show
+        Return:
+            img, target
+        """
+        target, img_info, _ = self.annotations[index]
+        img = self.read_img(index)
+
+        return img, target, img_info, index
+
+    @CacheDataset.mosaic_getitem
+    def __getitem__(self, index):
+        img, target, img_info, img_id = self.pull_item(index)
+
+        if self.preproc is not None:
+            img, target = self.preproc(img, target, self.input_dim)
+
+        return img, target, img_info, img_id
+
+    def evaluate_detections(self, all_boxes, output_dir=None):
+        """
+        all_boxes is a list of length number-of-classes.
+        Each list element is a list of length number-of-images.
+        Each of those list elements is either an empty list []
+        or a numpy array of detection.
+
+        all_boxes[class][image] = [] or np.array of shape #dets x 5
+        """
+        self._write_voc_results_file(all_boxes)
+        IouTh = np.linspace(
+            0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True
+        )
+        mAPs = []
+        for iou in IouTh:
+            mAP = self._do_python_eval(output_dir, iou)
+            mAPs.append(mAP)
+
+        print("--------------------------------------------------------------")
+        print("map_5095:", np.mean(mAPs))
+        print("map_50:", mAPs[0])
+        print("--------------------------------------------------------------")
+        return np.mean(mAPs), mAPs[0]
+
+    def _get_voc_results_file_template(self):
+        filename = "comp4_det_test" + "_{:s}.txt"
+        filedir = os.path.join(self.root, "results", "VOC" + self._year, "Main")
+        if not os.path.exists(filedir):
+            os.makedirs(filedir)
+        path = os.path.join(filedir, filename)
+        return path
+
+    def _write_voc_results_file(self, all_boxes):
+        for cls_ind, cls in enumerate(VOC_CLASSES):
+            cls_ind = cls_ind
+            if cls == "__background__":
+                continue
+            print("Writing {} VOC results file".format(cls))
+            filename = self._get_voc_results_file_template().format(cls)
+            with open(filename, "wt") as f:
+                for im_ind, index in enumerate(self.ids):
+                    index = index[1]
+                    dets = all_boxes[cls_ind][im_ind]
+                    if dets == []:
+                        continue
+                    for k in range(dets.shape[0]):
+                        f.write(
+                            "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format(
+                                index,
+                                dets[k, -1],
+                                dets[k, 0] + 1,
+                                dets[k, 1] + 1,
+                                dets[k, 2] + 1,
+                                dets[k, 3] + 1,
+                            )
+                        )
+
+    def _do_python_eval(self, output_dir="output", iou=0.5):
+        rootpath = os.path.join(self.root, "VOC" + self._year)
+        name = self.image_set[0][1]
+        annopath = os.path.join(rootpath, "Annotations", "{:s}.xml")
+        imagesetfile = os.path.join(rootpath, "ImageSets", "Main", name + ".txt")
+        cachedir = os.path.join(
+            self.root, "annotations_cache", "VOC" + self._year, name
+        )
+        if not os.path.exists(cachedir):
+            os.makedirs(cachedir)
+        aps = []
+        # The PASCAL VOC metric changed in 2010
+        use_07_metric = True if int(self._year) < 2010 else False
+        print("Eval IoU : {:.2f}".format(iou))
+        if output_dir is not None and not os.path.isdir(output_dir):
+            os.mkdir(output_dir)
+        for i, cls in enumerate(VOC_CLASSES):
+
+            if cls == "__background__":
+                continue
+
+            filename = self._get_voc_results_file_template().format(cls)
+            rec, prec, ap = voc_eval(
+                filename,
+                annopath,
+                imagesetfile,
+                cls,
+                cachedir,
+                ovthresh=iou,
+                use_07_metric=use_07_metric,
+            )
+            aps += [ap]
+            if iou == 0.5:
+                print("AP for {} = {:.4f}".format(cls, ap))
+            if output_dir is not None:
+                with open(os.path.join(output_dir, cls + "_pr.pkl"), "wb") as f:
+                    pickle.dump({"rec": rec, "prec": prec, "ap": ap}, f)
+        if iou == 0.5:
+            print("Mean AP = {:.4f}".format(np.mean(aps)))
+            print("~~~~~~~~")
+            print("Results:")
+            for ap in aps:
+                print("{:.3f}".format(ap))
+            print("{:.3f}".format(np.mean(aps)))
+            print("~~~~~~~~")
+            print("")
+            print("--------------------------------------------------------------")
+            print("Results computed with the **unofficial** Python eval code.")
+            print("Results should be very close to the official MATLAB eval code.")
+            print("Recompute with `./tools/reval.py --matlab ...` for your paper.")
+            print("-- Thanks, The Management")
+            print("--------------------------------------------------------------")
+
+        return np.mean(aps)
diff --git a/multimodal/YOLOX/yolox/data/datasets/voc_classes.py b/multimodal/YOLOX/yolox/data/datasets/voc_classes.py
new file mode 100644
index 0000000000000000000000000000000000000000..89354b3fdb19195f63f76ed56c86565323de5434
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/datasets/voc_classes.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+# VOC_CLASSES = ( '__background__', # always index 0
+VOC_CLASSES = (
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+)
diff --git a/multimodal/YOLOX/yolox/data/samplers.py b/multimodal/YOLOX/yolox/data/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b7ea38d3cd5bc0c906229b48ceaa51483173c42
--- /dev/null
+++ b/multimodal/YOLOX/yolox/data/samplers.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import itertools
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import BatchSampler as torchBatchSampler
+from torch.utils.data.sampler import Sampler
+
+
+class YoloBatchSampler(torchBatchSampler):
+    """
+    This batch sampler will generate mini-batches of (mosaic, index) tuples from another sampler.
+    It works just like the :class:`torch.utils.data.sampler.BatchSampler`,
+    but it will turn on/off the mosaic aug.
+    """
+
+    def __init__(self, *args, mosaic=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.mosaic = mosaic
+
+    def __iter__(self):
+        for batch in super().__iter__():
+            yield [(self.mosaic, idx) for idx in batch]
+
+
+class InfiniteSampler(Sampler):
+    """
+    In training, we only care about the "infinite stream" of training data.
+    So this sampler produces an infinite stream of indices and
+    all workers cooperate to correctly shuffle the indices and sample different indices.
+    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
+    where `indices` is an infinite stream of indices consisting of
+    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
+    or `range(size) + range(size) + ...` (if shuffle is False)
+    """
+
+    def __init__(
+        self,
+        size: int,
+        shuffle: bool = True,
+        seed: Optional[int] = 0,
+        rank=0,
+        world_size=1,
+    ):
+        """
+        Args:
+            size (int): the total number of data of the underlying dataset to sample from
+            shuffle (bool): whether to shuffle the indices or not
+            seed (int): the initial seed of the shuffle. Must be the same
+                across all workers. If None, will use a random seed shared
+                among workers (require synchronization among all workers).
+        """
+        self._size = size
+        assert size > 0
+        self._shuffle = shuffle
+        self._seed = int(seed)
+
+        if dist.is_available() and dist.is_initialized():
+            self._rank = dist.get_rank()
+            self._world_size = dist.get_world_size()
+        else:
+            self._rank = rank
+            self._world_size = world_size
+
+    def __iter__(self):
+        start = self._rank
+        yield from itertools.islice(
+            self._infinite_indices(), start, None, self._world_size
+        )
+
+    def _infinite_indices(self):
+        g = torch.Generator()
+        g.manual_seed(self._seed)
+        while True:
+            if self._shuffle:
+                yield from torch.randperm(self._size, generator=g)
+            else:
+                yield from torch.arange(self._size)
+
+    def __len__(self):
+        return self._size // self._world_size
diff --git a/multimodal/YOLOX/yolox/evaluators/__init__.py b/multimodal/YOLOX/yolox/evaluators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a99047b4bcd5cfba68540fd94ee80926bb0044b
--- /dev/null
+++ b/multimodal/YOLOX/yolox/evaluators/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+from .coco_evaluator import COCOEvaluator
+from .voc_evaluator import VOCEvaluator
diff --git a/multimodal/YOLOX/yolox/evaluators/coco_evaluator.py b/multimodal/YOLOX/yolox/evaluators/coco_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e218c745624e5330dbae37dcac60f83052bf2f31
--- /dev/null
+++ b/multimodal/YOLOX/yolox/evaluators/coco_evaluator.py
@@ -0,0 +1,317 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import contextlib
+import io
+import itertools
+import json
+import tempfile
+import time
+from collections import ChainMap, defaultdict
+from loguru import logger
+from tabulate import tabulate
+from tqdm import tqdm
+
+import numpy as np
+
+import torch
+
+from yolox.data.datasets import COCO_CLASSES
+from yolox.utils import (
+    gather,
+    is_main_process,
+    postprocess,
+    synchronize,
+    time_synchronized,
+    xyxy2xywh
+)
+
+
+def per_class_AR_table(coco_eval, class_names=COCO_CLASSES, headers=["class", "AR"], colums=6):
+    per_class_AR = {}
+    recalls = coco_eval.eval["recall"]
+    # dimension of recalls: [TxKxAxM]
+    # recall has dims (iou, cls, area range, max dets)
+    assert len(class_names) == recalls.shape[1]
+
+    for idx, name in enumerate(class_names):
+        recall = recalls[:, idx, 0, -1]
+        recall = recall[recall > -1]
+        ar = np.mean(recall) if recall.size else float("nan")
+        per_class_AR[name] = float(ar * 100)
+
+    num_cols = min(colums, len(per_class_AR) * len(headers))
+    result_pair = [x for pair in per_class_AR.items() for x in pair]
+    row_pair = itertools.zip_longest(*[result_pair[i::num_cols] for i in range(num_cols)])
+    table_headers = headers * (num_cols // len(headers))
+    table = tabulate(
+        row_pair, tablefmt="pipe", floatfmt=".3f", headers=table_headers, numalign="left",
+    )
+    return table
+
+
+def per_class_AP_table(coco_eval, class_names=COCO_CLASSES, headers=["class", "AP"], colums=6):
+    per_class_AP = {}
+    precisions = coco_eval.eval["precision"]
+    # dimension of precisions: [TxRxKxAxM]
+    # precision has dims (iou, recall, cls, area range, max dets)
+    assert len(class_names) == precisions.shape[2]
+
+    for idx, name in enumerate(class_names):
+        # area range index 0: all area ranges
+        # max dets index -1: typically 100 per image
+        precision = precisions[:, :, idx, 0, -1]
+        precision = precision[precision > -1]
+        ap = np.mean(precision) if precision.size else float("nan")
+        per_class_AP[name] = float(ap * 100)
+
+    num_cols = min(colums, len(per_class_AP) * len(headers))
+    result_pair = [x for pair in per_class_AP.items() for x in pair]
+    row_pair = itertools.zip_longest(*[result_pair[i::num_cols] for i in range(num_cols)])
+    table_headers = headers * (num_cols // len(headers))
+    table = tabulate(
+        row_pair, tablefmt="pipe", floatfmt=".3f", headers=table_headers, numalign="left",
+    )
+    return table
+
+
+class COCOEvaluator:
+    """
+    COCO AP Evaluation class.  All the data in the val2017 dataset are processed
+    and evaluated by COCO API.
+    """
+
+    def __init__(
+        self,
+        dataloader,
+        img_size: int,
+        confthre: float,
+        nmsthre: float,
+        num_classes: int,
+        testdev: bool = False,
+        per_class_AP: bool = True,
+        per_class_AR: bool = True,
+    ):
+        """
+        Args:
+            dataloader (Dataloader): evaluate dataloader.
+            img_size: image size after preprocess. images are resized
+                to squares whose shape is (img_size, img_size).
+            confthre: confidence threshold ranging from 0 to 1, which
+                is defined in the config file.
+            nmsthre: IoU threshold of non-max supression ranging from 0 to 1.
+            per_class_AP: Show per class AP during evalution or not. Default to True.
+            per_class_AR: Show per class AR during evalution or not. Default to True.
+        """
+        self.dataloader = dataloader
+        self.img_size = img_size
+        self.confthre = confthre
+        self.nmsthre = nmsthre
+        self.num_classes = num_classes
+        self.testdev = testdev
+        self.per_class_AP = per_class_AP
+        self.per_class_AR = per_class_AR
+
+    def evaluate(
+        self, model, distributed=False, half=False, trt_file=None,
+        decoder=None, test_size=None, return_outputs=False
+    ):
+        """
+        COCO average precision (AP) Evaluation. Iterate inference on the test dataset
+        and the results are evaluated by COCO API.
+
+        NOTE: This function will change training mode to False, please save states if needed.
+
+        Args:
+            model : model to evaluate.
+
+        Returns:
+            ap50_95 (float) : COCO AP of IoU=50:95
+            ap50 (float) : COCO AP of IoU=50
+            summary (sr): summary info of evaluation.
+        """
+        # TODO half to amp_test
+        tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
+        model = model.eval()
+        if half:
+            model = model.half()
+        ids = []
+        data_list = []
+        output_data = defaultdict()
+        progress_bar = tqdm if is_main_process() else iter
+
+        inference_time = 0
+        nms_time = 0
+        n_samples = max(len(self.dataloader) - 1, 1)
+
+        if trt_file is not None:
+            from torch2trt import TRTModule
+
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(trt_file))
+
+            x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
+            model(x)
+            model = model_trt
+
+        for cur_iter, (imgs, _, info_imgs, ids) in enumerate(
+            progress_bar(self.dataloader)
+        ):
+            with torch.no_grad():
+                imgs = imgs.type(tensor_type)
+
+                # skip the last iters since batchsize might be not enough for batch inference
+                is_time_record = cur_iter < len(self.dataloader) - 1
+                if is_time_record:
+                    start = time.time()
+
+                outputs = model(imgs)
+                if decoder is not None:
+                    outputs = decoder(outputs, dtype=outputs.type())
+
+                if is_time_record:
+                    infer_end = time_synchronized()
+                    inference_time += infer_end - start
+
+                outputs = postprocess(
+                    outputs, self.num_classes, self.confthre, self.nmsthre
+                )
+                if is_time_record:
+                    nms_end = time_synchronized()
+                    nms_time += nms_end - infer_end
+
+            data_list_elem, image_wise_data = self.convert_to_coco_format(
+                outputs, info_imgs, ids, return_outputs=True)
+            data_list.extend(data_list_elem)
+            output_data.update(image_wise_data)
+
+        statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
+        if distributed:
+            # different process/device might have different speed,
+            # to make sure the process will not be stucked, sync func is used here.
+            synchronize()
+            data_list = gather(data_list, dst=0)
+            output_data = gather(output_data, dst=0)
+            data_list = list(itertools.chain(*data_list))
+            output_data = dict(ChainMap(*output_data))
+            torch.distributed.reduce(statistics, dst=0)
+
+        eval_results = self.evaluate_prediction(data_list, statistics)
+        synchronize()
+
+        if return_outputs:
+            return eval_results, output_data
+        return eval_results
+
+    def convert_to_coco_format(self, outputs, info_imgs, ids, return_outputs=False):
+        data_list = []
+        image_wise_data = defaultdict(dict)
+        for (output, img_h, img_w, img_id) in zip(
+            outputs, info_imgs[0], info_imgs[1], ids
+        ):
+            if output is None:
+                continue
+            output = output.cpu()
+
+            bboxes = output[:, 0:4]
+
+            # preprocessing: resize
+            scale = min(
+                self.img_size[0] / float(img_h), self.img_size[1] / float(img_w)
+            )
+            bboxes /= scale
+            cls = output[:, 6]
+            scores = output[:, 4] * output[:, 5]
+
+            image_wise_data.update({
+                int(img_id): {
+                    "bboxes": [box.numpy().tolist() for box in bboxes],
+                    "scores": [score.numpy().item() for score in scores],
+                    "categories": [
+                        self.dataloader.dataset.class_ids[int(cls[ind])]
+                        for ind in range(bboxes.shape[0])
+                    ],
+                }
+            })
+
+            bboxes = xyxy2xywh(bboxes)
+
+            for ind in range(bboxes.shape[0]):
+                label = self.dataloader.dataset.class_ids[int(cls[ind])]
+                pred_data = {
+                    "image_id": int(img_id),
+                    "category_id": label,
+                    "bbox": bboxes[ind].numpy().tolist(),
+                    "score": scores[ind].numpy().item(),
+                    "segmentation": [],
+                }  # COCO json format
+                data_list.append(pred_data)
+
+        if return_outputs:
+            return data_list, image_wise_data
+        return data_list
+
+    def evaluate_prediction(self, data_dict, statistics):
+        if not is_main_process():
+            return 0, 0, None
+
+        logger.info("Evaluate in main process...")
+
+        annType = ["segm", "bbox", "keypoints"]
+
+        inference_time = statistics[0].item()
+        nms_time = statistics[1].item()
+        n_samples = statistics[2].item()
+
+        a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
+        a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)
+
+        time_info = ", ".join(
+            [
+                "Average {} time: {:.2f} ms".format(k, v)
+                for k, v in zip(
+                    ["forward", "NMS", "inference"],
+                    [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
+                )
+            ]
+        )
+
+        info = time_info + "\n"
+
+        # Evaluate the Dt (detection) json comparing with the ground truth
+        if len(data_dict) > 0:
+            cocoGt = self.dataloader.dataset.coco
+            # TODO: since pycocotools can't process dict in py36, write data to json file.
+            if self.testdev:
+                json.dump(data_dict, open("./yolox_testdev_2017.json", "w"))
+                cocoDt = cocoGt.loadRes("./yolox_testdev_2017.json")
+            else:
+                _, tmp = tempfile.mkstemp()
+                json.dump(data_dict, open(tmp, "w"))
+                cocoDt = cocoGt.loadRes(tmp)
+            try:
+                from yolox.layers import COCOeval_opt as COCOeval
+            except ImportError:
+                from pycocotools.cocoeval import COCOeval
+
+                logger.warning("Use standard COCOeval.")
+
+            cocoEval = COCOeval(cocoGt, cocoDt, annType[1])
+            cocoEval.evaluate()
+            cocoEval.accumulate()
+            redirect_string = io.StringIO()
+            with contextlib.redirect_stdout(redirect_string):
+                cocoEval.summarize()
+            info += redirect_string.getvalue()
+            cat_ids = list(cocoGt.cats.keys())
+            cat_names = [cocoGt.cats[catId]['name'] for catId in sorted(cat_ids)]
+            if self.per_class_AP:
+                AP_table = per_class_AP_table(cocoEval, class_names=cat_names)
+                info += "per class AP:\n" + AP_table + "\n"
+            if self.per_class_AR:
+                AR_table = per_class_AR_table(cocoEval, class_names=cat_names)
+                info += "per class AR:\n" + AR_table + "\n"
+            return cocoEval.stats[0], cocoEval.stats[1], info
+        else:
+            return 0, 0, info
diff --git a/multimodal/YOLOX/yolox/evaluators/voc_eval.py b/multimodal/YOLOX/yolox/evaluators/voc_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1a474861e0a760c1e180dc62803100f030458bd
--- /dev/null
+++ b/multimodal/YOLOX/yolox/evaluators/voc_eval.py
@@ -0,0 +1,183 @@
+#!/usr/bin/env python3
+# Code are based on
+# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py
+# Copyright (c) Bharath Hariharan.
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+import pickle
+import xml.etree.ElementTree as ET
+
+import numpy as np
+
+
+def parse_rec(filename):
+    """Parse a PASCAL VOC xml file"""
+    tree = ET.parse(filename)
+    objects = []
+    for obj in tree.findall("object"):
+        obj_struct = {}
+        obj_struct["name"] = obj.find("name").text
+        obj_struct["pose"] = obj.find("pose").text
+        obj_struct["truncated"] = int(obj.find("truncated").text)
+        obj_struct["difficult"] = int(obj.find("difficult").text)
+        bbox = obj.find("bndbox")
+        obj_struct["bbox"] = [
+            int(bbox.find("xmin").text),
+            int(bbox.find("ymin").text),
+            int(bbox.find("xmax").text),
+            int(bbox.find("ymax").text),
+        ]
+        objects.append(obj_struct)
+
+    return objects
+
+
+def voc_ap(rec, prec, use_07_metric=False):
+    """
+    Compute VOC AP given precision and recall.
+    If use_07_metric is true, uses the
+    VOC 07 11 point method (default:False).
+    """
+    if use_07_metric:
+        # 11 point metric
+        ap = 0.0
+        for t in np.arange(0.0, 1.1, 0.1):
+            if np.sum(rec >= t) == 0:
+                p = 0
+            else:
+                p = np.max(prec[rec >= t])
+            ap = ap + p / 11.0
+    else:
+        # correct AP calculation
+        # first append sentinel values at the end
+        mrec = np.concatenate(([0.0], rec, [1.0]))
+        mpre = np.concatenate(([0.0], prec, [0.0]))
+
+        # compute the precision envelope
+        for i in range(mpre.size - 1, 0, -1):
+            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+        # to calculate area under PR curve, look for points
+        # where X axis (recall) changes value
+        i = np.where(mrec[1:] != mrec[:-1])[0]
+
+        # and sum (\Delta recall) * prec
+        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+    return ap
+
+
+def voc_eval(
+    detpath,
+    annopath,
+    imagesetfile,
+    classname,
+    cachedir,
+    ovthresh=0.5,
+    use_07_metric=False,
+):
+    # first load gt
+    if not os.path.isdir(cachedir):
+        os.mkdir(cachedir)
+    cachefile = os.path.join(cachedir, "annots.pkl")
+    # read list of images
+    with open(imagesetfile, "r") as f:
+        lines = f.readlines()
+    imagenames = [x.strip() for x in lines]
+
+    if not os.path.isfile(cachefile):
+        # load annots
+        recs = {}
+        for i, imagename in enumerate(imagenames):
+            recs[imagename] = parse_rec(annopath.format(imagename))
+            if i % 100 == 0:
+                print(f"Reading annotation for {i + 1}/{len(imagenames)}")
+        # save
+        print(f"Saving cached annotations to {cachefile}")
+        with open(cachefile, "wb") as f:
+            pickle.dump(recs, f)
+    else:
+        # load
+        with open(cachefile, "rb") as f:
+            recs = pickle.load(f)
+
+    # extract gt objects for this class
+    class_recs = {}
+    npos = 0
+    for imagename in imagenames:
+        R = [obj for obj in recs[imagename] if obj["name"] == classname]
+        bbox = np.array([x["bbox"] for x in R])
+        difficult = np.array([x["difficult"] for x in R]).astype(bool)
+        det = [False] * len(R)
+        npos = npos + sum(~difficult)
+        class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}
+
+    # read dets
+    detfile = detpath.format(classname)
+    with open(detfile, "r") as f:
+        lines = f.readlines()
+
+    if len(lines) == 0:
+        return 0, 0, 0
+
+    splitlines = [x.strip().split(" ") for x in lines]
+    image_ids = [x[0] for x in splitlines]
+    confidence = np.array([float(x[1]) for x in splitlines])
+    BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
+
+    # sort by confidence
+    sorted_ind = np.argsort(-confidence)
+    BB = BB[sorted_ind, :]
+    image_ids = [image_ids[x] for x in sorted_ind]
+
+    # go down dets and mark TPs and FPs
+    nd = len(image_ids)
+    tp = np.zeros(nd)
+    fp = np.zeros(nd)
+    for d in range(nd):
+        R = class_recs[image_ids[d]]
+        bb = BB[d, :].astype(float)
+        ovmax = -np.inf
+        BBGT = R["bbox"].astype(float)
+
+        if BBGT.size > 0:
+            # compute overlaps
+            # intersection
+            ixmin = np.maximum(BBGT[:, 0], bb[0])
+            iymin = np.maximum(BBGT[:, 1], bb[1])
+            ixmax = np.minimum(BBGT[:, 2], bb[2])
+            iymax = np.minimum(BBGT[:, 3], bb[3])
+            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
+            ih = np.maximum(iymax - iymin + 1.0, 0.0)
+            inters = iw * ih
+
+            # union
+            uni = (
+                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
+                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0) - inters
+            )
+
+            overlaps = inters / uni
+            ovmax = np.max(overlaps)
+            jmax = np.argmax(overlaps)
+
+        if ovmax > ovthresh:
+            if not R["difficult"][jmax]:
+                if not R["det"][jmax]:
+                    tp[d] = 1.0
+                    R["det"][jmax] = 1
+                else:
+                    fp[d] = 1.0
+        else:
+            fp[d] = 1.0
+
+        # compute precision recall
+    fp = np.cumsum(fp)
+    tp = np.cumsum(tp)
+    rec = tp / float(npos)
+    # avoid divide by zero in case the first detection matches a difficult
+    # ground truth
+    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+    ap = voc_ap(rec, prec, use_07_metric)
+
+    return rec, prec, ap
diff --git a/multimodal/YOLOX/yolox/evaluators/voc_evaluator.py b/multimodal/YOLOX/yolox/evaluators/voc_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..094df3d6978abc39af9fc5d28ceb3548fa9a0417
--- /dev/null
+++ b/multimodal/YOLOX/yolox/evaluators/voc_evaluator.py
@@ -0,0 +1,187 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import sys
+import tempfile
+import time
+from collections import ChainMap
+from loguru import logger
+from tqdm import tqdm
+
+import numpy as np
+
+import torch
+
+from yolox.utils import gather, is_main_process, postprocess, synchronize, time_synchronized
+
+
+class VOCEvaluator:
+    """
+    VOC AP Evaluation class.
+    """
+
+    def __init__(self, dataloader, img_size, confthre, nmsthre, num_classes):
+        """
+        Args:
+            dataloader (Dataloader): evaluate dataloader.
+            img_size (int): image size after preprocess. images are resized
+                to squares whose shape is (img_size, img_size).
+            confthre (float): confidence threshold ranging from 0 to 1, which
+                is defined in the config file.
+            nmsthre (float): IoU threshold of non-max supression ranging from 0 to 1.
+        """
+        self.dataloader = dataloader
+        self.img_size = img_size
+        self.confthre = confthre
+        self.nmsthre = nmsthre
+        self.num_classes = num_classes
+        self.num_images = len(dataloader.dataset)
+
+    def evaluate(
+        self, model, distributed=False, half=False, trt_file=None,
+        decoder=None, test_size=None, return_outputs=False,
+    ):
+        """
+        VOC average precision (AP) Evaluation. Iterate inference on the test dataset
+        and the results are evaluated by COCO API.
+
+        NOTE: This function will change training mode to False, please save states if needed.
+
+        Args:
+            model : model to evaluate.
+
+        Returns:
+            ap50_95 (float) : COCO style AP of IoU=50:95
+            ap50 (float) : VOC 2007 metric AP of IoU=50
+            summary (sr): summary info of evaluation.
+        """
+        # TODO half to amp_test
+        tensor_type = torch.cuda.HalfTensor if half else torch.cuda.FloatTensor
+        model = model.eval()
+        if half:
+            model = model.half()
+        ids = []
+        data_list = {}
+        progress_bar = tqdm if is_main_process() else iter
+
+        inference_time = 0
+        nms_time = 0
+        n_samples = max(len(self.dataloader) - 1, 1)
+
+        if trt_file is not None:
+            from torch2trt import TRTModule
+
+            model_trt = TRTModule()
+            model_trt.load_state_dict(torch.load(trt_file))
+
+            x = torch.ones(1, 3, test_size[0], test_size[1]).cuda()
+            model(x)
+            model = model_trt
+
+        for cur_iter, (imgs, _, info_imgs, ids) in enumerate(progress_bar(self.dataloader)):
+            with torch.no_grad():
+                imgs = imgs.type(tensor_type)
+
+                # skip the last iters since batchsize might be not enough for batch inference
+                is_time_record = cur_iter < len(self.dataloader) - 1
+                if is_time_record:
+                    start = time.time()
+
+                outputs = model(imgs)
+                if decoder is not None:
+                    outputs = decoder(outputs, dtype=outputs.type())
+
+                if is_time_record:
+                    infer_end = time_synchronized()
+                    inference_time += infer_end - start
+
+                outputs = postprocess(
+                    outputs, self.num_classes, self.confthre, self.nmsthre
+                )
+                if is_time_record:
+                    nms_end = time_synchronized()
+                    nms_time += nms_end - infer_end
+
+            data_list.update(self.convert_to_voc_format(outputs, info_imgs, ids))
+
+        statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
+        if distributed:
+            data_list = gather(data_list, dst=0)
+            data_list = ChainMap(*data_list)
+            torch.distributed.reduce(statistics, dst=0)
+
+        eval_results = self.evaluate_prediction(data_list, statistics)
+        synchronize()
+        if return_outputs:
+            return eval_results, data_list
+        return eval_results
+
+    def convert_to_voc_format(self, outputs, info_imgs, ids):
+        predictions = {}
+        for output, img_h, img_w, img_id in zip(outputs, info_imgs[0], info_imgs[1], ids):
+            if output is None:
+                predictions[int(img_id)] = (None, None, None)
+                continue
+            output = output.cpu()
+
+            bboxes = output[:, 0:4]
+
+            # preprocessing: resize
+            scale = min(self.img_size[0] / float(img_h), self.img_size[1] / float(img_w))
+            bboxes /= scale
+
+            cls = output[:, 6]
+            scores = output[:, 4] * output[:, 5]
+
+            predictions[int(img_id)] = (bboxes, cls, scores)
+        return predictions
+
+    def evaluate_prediction(self, data_dict, statistics):
+        if not is_main_process():
+            return 0, 0, None
+
+        logger.info("Evaluate in main process...")
+
+        inference_time = statistics[0].item()
+        nms_time = statistics[1].item()
+        n_samples = statistics[2].item()
+
+        a_infer_time = 1000 * inference_time / (n_samples * self.dataloader.batch_size)
+        a_nms_time = 1000 * nms_time / (n_samples * self.dataloader.batch_size)
+
+        time_info = ", ".join(
+            [
+                "Average {} time: {:.2f} ms".format(k, v)
+                for k, v in zip(
+                    ["forward", "NMS", "inference"],
+                    [a_infer_time, a_nms_time, (a_infer_time + a_nms_time)],
+                )
+            ]
+        )
+        info = time_info + "\n"
+
+        all_boxes = [
+            [[] for _ in range(self.num_images)] for _ in range(self.num_classes)
+        ]
+        for img_num in range(self.num_images):
+            bboxes, cls, scores = data_dict[img_num]
+            if bboxes is None:
+                for j in range(self.num_classes):
+                    all_boxes[j][img_num] = np.empty([0, 5], dtype=np.float32)
+                continue
+            for j in range(self.num_classes):
+                mask_c = cls == j
+                if sum(mask_c) == 0:
+                    all_boxes[j][img_num] = np.empty([0, 5], dtype=np.float32)
+                    continue
+
+                c_dets = torch.cat((bboxes, scores.unsqueeze(1)), dim=1)
+                all_boxes[j][img_num] = c_dets[mask_c].numpy()
+
+            sys.stdout.write(f"im_eval: {img_num + 1}/{self.num_images} \r")
+            sys.stdout.flush()
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            mAP50, mAP70 = self.dataloader.dataset.evaluate_detections(all_boxes, tempdir)
+            return mAP50, mAP70, info
diff --git a/multimodal/YOLOX/yolox/exp/__init__.py b/multimodal/YOLOX/yolox/exp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..40e5f58df9aeeb9590a9de66f5a2150bf1a37273
--- /dev/null
+++ b/multimodal/YOLOX/yolox/exp/__init__.py
@@ -0,0 +1,6 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+from .base_exp import BaseExp
+from .build import get_exp
+from .yolox_base import Exp, check_exp_value
diff --git a/multimodal/YOLOX/yolox/exp/base_exp.py b/multimodal/YOLOX/yolox/exp/base_exp.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ccfec5c255f0e27894165a99d5f45383560a89e
--- /dev/null
+++ b/multimodal/YOLOX/yolox/exp/base_exp.py
@@ -0,0 +1,90 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import ast
+import pprint
+from abc import ABCMeta, abstractmethod
+from typing import Dict, List, Tuple
+from tabulate import tabulate
+
+import torch
+from torch.nn import Module
+
+from yolox.utils import LRScheduler
+
+
+class BaseExp(metaclass=ABCMeta):
+    """Basic class for any experiment."""
+
+    def __init__(self):
+        self.seed = None
+        self.output_dir = "./YOLOX_outputs"
+        self.print_interval = 100
+        self.eval_interval = 10
+        self.dataset = None
+
+    @abstractmethod
+    def get_model(self) -> Module:
+        pass
+
+    @abstractmethod
+    def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
+        pass
+
+    @abstractmethod
+    def get_data_loader(
+        self, batch_size: int, is_distributed: bool
+    ) -> Dict[str, torch.utils.data.DataLoader]:
+        pass
+
+    @abstractmethod
+    def get_optimizer(self, batch_size: int) -> torch.optim.Optimizer:
+        pass
+
+    @abstractmethod
+    def get_lr_scheduler(
+        self, lr: float, iters_per_epoch: int, **kwargs
+    ) -> LRScheduler:
+        pass
+
+    @abstractmethod
+    def get_evaluator(self):
+        pass
+
+    @abstractmethod
+    def eval(self, model, evaluator, weights):
+        pass
+
+    def __repr__(self):
+        table_header = ["keys", "values"]
+        exp_table = [
+            (str(k), pprint.pformat(v))
+            for k, v in vars(self).items()
+            if not k.startswith("_")
+        ]
+        return tabulate(exp_table, headers=table_header, tablefmt="fancy_grid")
+
+    def merge(self, cfg_list):
+        assert len(cfg_list) % 2 == 0, f"length must be even, check value here: {cfg_list}"
+        for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
+            # only update value with same key
+            if hasattr(self, k):
+                src_value = getattr(self, k)
+                src_type = type(src_value)
+
+                # pre-process input if source type is list or tuple
+                if isinstance(src_value, (List, Tuple)):
+                    v = v.strip("[]()")
+                    v = [t.strip() for t in v.split(",")]
+
+                    # find type of tuple
+                    if len(src_value) > 0:
+                        src_item_type = type(src_value[0])
+                        v = [src_item_type(t) for t in v]
+
+                if src_value is not None and src_type != type(v):
+                    try:
+                        v = src_type(v)
+                    except Exception:
+                        v = ast.literal_eval(v)
+                setattr(self, k, v)
diff --git a/multimodal/YOLOX/yolox/exp/build.py b/multimodal/YOLOX/yolox/exp/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef83f76facc21677b1e238a4798304357a04832a
--- /dev/null
+++ b/multimodal/YOLOX/yolox/exp/build.py
@@ -0,0 +1,42 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import importlib
+import os
+import sys
+
+
+def get_exp_by_file(exp_file):
+    try:
+        sys.path.append(os.path.dirname(exp_file))
+        current_exp = importlib.import_module(os.path.basename(exp_file).split(".")[0])
+        exp = current_exp.Exp()
+    except Exception:
+        raise ImportError("{} doesn't contains class named 'Exp'".format(exp_file))
+    return exp
+
+
+def get_exp_by_name(exp_name):
+    exp = exp_name.replace("-", "_")  # convert string like "yolox-s" to "yolox_s"
+    module_name = ".".join(["yolox", "exp", "default", exp])
+    exp_object = importlib.import_module(module_name).Exp()
+    return exp_object
+
+
+def get_exp(exp_file=None, exp_name=None):
+    """
+    get Exp object by file or name. If exp_file and exp_name
+    are both provided, get Exp by exp_file.
+
+    Args:
+        exp_file (str): file path of experiment.
+        exp_name (str): name of experiment. "yolo-s",
+    """
+    assert (
+        exp_file is not None or exp_name is not None
+    ), "plz provide exp file or exp name."
+    if exp_file is not None:
+        return get_exp_by_file(exp_file)
+    else:
+        return get_exp_by_name(exp_name)
diff --git a/multimodal/YOLOX/yolox/exp/default/__init__.py b/multimodal/YOLOX/yolox/exp/default/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..68a1d1f0fc58ef34f12134dd20e592ddf7c53878
--- /dev/null
+++ b/multimodal/YOLOX/yolox/exp/default/__init__.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+# This file is used for package installation and find default exp file
+
+import sys
+from importlib import abc, util
+from pathlib import Path
+
+_EXP_PATH = Path(__file__).resolve().parent.parent.parent.parent / "exps" / "default"
+
+if _EXP_PATH.is_dir():
+    # This is true only for in-place installation (pip install -e, setup.py develop),
+    # where setup(package_dir=) does not work: https://github.com/pypa/setuptools/issues/230
+
+    class _ExpFinder(abc.MetaPathFinder):
+        
+        def find_spec(self, name, path, target=None):
+            if not name.startswith("yolox.exp.default"):
+                return
+            project_name = name.split(".")[-1] + ".py"
+            target_file = _EXP_PATH / project_name
+            if not target_file.is_file():
+                return
+            return util.spec_from_file_location(name, target_file)
+
+    sys.meta_path.append(_ExpFinder())
diff --git a/multimodal/YOLOX/yolox/exp/yolox_base.py b/multimodal/YOLOX/yolox/exp/yolox_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..82e93c21bded09a835ce9d27957020bf849a4ae9
--- /dev/null
+++ b/multimodal/YOLOX/yolox/exp/yolox_base.py
@@ -0,0 +1,358 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import os
+import random
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from .base_exp import BaseExp
+
+__all__ = ["Exp", "check_exp_value"]
+
+
+class Exp(BaseExp):
+    def __init__(self):
+        super().__init__()
+
+        # ---------------- model config ---------------- #
+        # detect classes number of model
+        self.num_classes = 80
+        # factor of model depth
+        self.depth = 1.00
+        # factor of model width
+        self.width = 1.00
+        # activation name. For example, if using "relu", then "silu" will be replaced to "relu".
+        self.act = "silu"
+
+        # ---------------- dataloader config ---------------- #
+        # set worker to 4 for shorter dataloader init time
+        # If your training process cost many memory, reduce this value.
+        self.data_num_workers = 4
+        self.input_size = (640, 640)  # (height, width)
+        # Actual multiscale ranges: [640 - 5 * 32, 640 + 5 * 32].
+        # To disable multiscale training, set the value to 0.
+        self.multiscale_range = 5
+        # You can uncomment this line to specify a multiscale range
+        # self.random_size = (14, 26)
+        # dir of dataset images, if data_dir is None, this project will use `datasets` dir
+        self.data_dir = None
+        # name of annotation file for training
+        self.train_ann = "instances_train2017.json"
+        # name of annotation file for evaluation
+        self.val_ann = "instances_val2017.json"
+        # name of annotation file for testing
+        self.test_ann = "instances_test2017.json"
+
+        # --------------- transform config ----------------- #
+        # prob of applying mosaic aug
+        self.mosaic_prob = 1.0
+        # prob of applying mixup aug
+        self.mixup_prob = 1.0
+        # prob of applying hsv aug
+        self.hsv_prob = 1.0
+        # prob of applying flip aug
+        self.flip_prob = 0.5
+        # rotation angle range, for example, if set to 2, the true range is (-2, 2)
+        self.degrees = 10.0
+        # translate range, for example, if set to 0.1, the true range is (-0.1, 0.1)
+        self.translate = 0.1
+        self.mosaic_scale = (0.1, 2)
+        # apply mixup aug or not
+        self.enable_mixup = True
+        self.mixup_scale = (0.5, 1.5)
+        # shear angle range, for example, if set to 2, the true range is (-2, 2)
+        self.shear = 2.0
+
+        # --------------  training config --------------------- #
+        # epoch number used for warmup
+        self.warmup_epochs = 5
+        # max training epoch
+        self.max_epoch = 300
+        # minimum learning rate during warmup
+        self.warmup_lr = 0
+        self.min_lr_ratio = 0.05
+        # learning rate for one image. During training, lr will multiply batchsize.
+        self.basic_lr_per_img = 0.01 / 64.0
+        # name of LRScheduler
+        self.scheduler = "yoloxwarmcos"
+        # last #epoch to close augmention like mosaic
+        self.no_aug_epochs = 15
+        # apply EMA during training
+        self.ema = True
+
+        # weight decay of optimizer
+        self.weight_decay = 5e-4
+        # momentum of optimizer
+        self.momentum = 0.9
+        # log period in iter, for example,
+        # if set to 1, user could see log every iteration.
+        self.print_interval = 10
+        # eval period in epoch, for example,
+        # if set to 1, model will be evaluate after every epoch.
+        self.eval_interval = 10
+        # save history checkpoint or not.
+        # If set to False, yolox will only save latest and best ckpt.
+        self.save_history_ckpt = True
+        # name of experiment
+        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
+
+        # -----------------  testing config ------------------ #
+        # output image size during evaluation/test
+        self.test_size = (640, 640)
+        # confidence threshold during evaluation/test,
+        # boxes whose scores are less than test_conf will be filtered
+        self.test_conf = 0.01
+        # nms threshold
+        self.nmsthre = 0.65
+
+    def get_model(self):
+        from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead
+
+        def init_yolo(M):
+            for m in M.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eps = 1e-3
+                    m.momentum = 0.03
+
+        if getattr(self, "model", None) is None:
+            in_channels = [256, 512, 1024]
+            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act)
+            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act)
+            self.model = YOLOX(backbone, head)
+
+        self.model.apply(init_yolo)
+        self.model.head.initialize_biases(1e-2)
+        self.model.train()
+        return self.model
+
+    def get_dataset(self, cache: bool = False, cache_type: str = "ram"):
+        """
+        Get dataset according to cache and cache_type parameters.
+        Args:
+            cache (bool): Whether to cache imgs to ram or disk.
+            cache_type (str, optional): Defaults to "ram".
+                "ram" : Caching imgs to ram for fast training.
+                "disk": Caching imgs to disk for fast training.
+        """
+        from yolox.data import COCODataset, TrainTransform
+
+        return COCODataset(
+            data_dir=self.data_dir,
+            json_file=self.train_ann,
+            img_size=self.input_size,
+            preproc=TrainTransform(
+                max_labels=50,
+                flip_prob=self.flip_prob,
+                hsv_prob=self.hsv_prob
+            ),
+            cache=cache,
+            cache_type=cache_type,
+        )
+
+    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img: str = None):
+        """
+        Get dataloader according to cache_img parameter.
+        Args:
+            no_aug (bool, optional): Whether to turn off mosaic data enhancement. Defaults to False.
+            cache_img (str, optional): cache_img is equivalent to cache_type. Defaults to None.
+                "ram" : Caching imgs to ram for fast training.
+                "disk": Caching imgs to disk for fast training.
+                None: Do not use cache, in this case cache_data is also None.
+        """
+        from yolox.data import (
+            TrainTransform,
+            YoloBatchSampler,
+            DataLoader,
+            InfiniteSampler,
+            MosaicDetection,
+            worker_init_reset_seed,
+        )
+        from yolox.utils import wait_for_the_master
+
+        # if cache is True, we will create self.dataset before launch
+        # else we will create self.dataset after launch
+        if self.dataset is None:
+            with wait_for_the_master():
+                assert cache_img is None, \
+                    "cache_img must be None if you didn't create self.dataset before launch"
+                self.dataset = self.get_dataset(cache=False, cache_type=cache_img)
+
+        self.dataset = MosaicDetection(
+            dataset=self.dataset,
+            mosaic=not no_aug,
+            img_size=self.input_size,
+            preproc=TrainTransform(
+                max_labels=120,
+                flip_prob=self.flip_prob,
+                hsv_prob=self.hsv_prob),
+            degrees=self.degrees,
+            translate=self.translate,
+            mosaic_scale=self.mosaic_scale,
+            mixup_scale=self.mixup_scale,
+            shear=self.shear,
+            enable_mixup=self.enable_mixup,
+            mosaic_prob=self.mosaic_prob,
+            mixup_prob=self.mixup_prob,
+        )
+
+        if is_distributed:
+            batch_size = batch_size // dist.get_world_size()
+
+        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
+
+        batch_sampler = YoloBatchSampler(
+            sampler=sampler,
+            batch_size=batch_size,
+            drop_last=False,
+            mosaic=not no_aug,
+        )
+
+        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
+        dataloader_kwargs["batch_sampler"] = batch_sampler
+
+        # Make sure each process has different random seed, especially for 'fork' method.
+        # Check https://github.com/pytorch/pytorch/issues/63311 for more details.
+        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed
+
+        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
+
+        return train_loader
+
+    def random_resize(self, data_loader, epoch, rank, is_distributed):
+        tensor = torch.LongTensor(2).cuda()
+
+        if rank == 0:
+            size_factor = self.input_size[1] * 1.0 / self.input_size[0]
+            if not hasattr(self, 'random_size'):
+                min_size = int(self.input_size[0] / 32) - self.multiscale_range
+                max_size = int(self.input_size[0] / 32) + self.multiscale_range
+                self.random_size = (min_size, max_size)
+            size = random.randint(*self.random_size)
+            size = (int(32 * size), 32 * int(size * size_factor))
+            tensor[0] = size[0]
+            tensor[1] = size[1]
+
+        if is_distributed:
+            dist.barrier()
+            dist.broadcast(tensor, 0)
+
+        input_size = (tensor[0].item(), tensor[1].item())
+        return input_size
+
+    def preprocess(self, inputs, targets, tsize):
+        scale_y = tsize[0] / self.input_size[0]
+        scale_x = tsize[1] / self.input_size[1]
+        if scale_x != 1 or scale_y != 1:
+            inputs = nn.functional.interpolate(
+                inputs, size=tsize, mode="bilinear", align_corners=False
+            )
+            targets[..., 1::2] = targets[..., 1::2] * scale_x
+            targets[..., 2::2] = targets[..., 2::2] * scale_y
+        return inputs, targets
+
+    def get_optimizer(self, batch_size):
+        if "optimizer" not in self.__dict__:
+            if self.warmup_epochs > 0:
+                lr = self.warmup_lr
+            else:
+                lr = self.basic_lr_per_img * batch_size
+
+            pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
+
+            for k, v in self.model.named_modules():
+                if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
+                    pg2.append(v.bias)  # biases
+                if isinstance(v, nn.BatchNorm2d) or "bn" in k:
+                    pg0.append(v.weight)  # no decay
+                elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
+                    pg1.append(v.weight)  # apply decay
+
+            optimizer = torch.optim.SGD(
+                pg0, lr=lr, momentum=self.momentum, nesterov=True
+            )
+            optimizer.add_param_group(
+                {"params": pg1, "weight_decay": self.weight_decay}
+            )  # add pg1 with weight_decay
+            optimizer.add_param_group({"params": pg2})
+            self.optimizer = optimizer
+
+        return self.optimizer
+
+    def get_lr_scheduler(self, lr, iters_per_epoch):
+        from yolox.utils import LRScheduler
+
+        scheduler = LRScheduler(
+            self.scheduler,
+            lr,
+            iters_per_epoch,
+            self.max_epoch,
+            warmup_epochs=self.warmup_epochs,
+            warmup_lr_start=self.warmup_lr,
+            no_aug_epochs=self.no_aug_epochs,
+            min_lr_ratio=self.min_lr_ratio,
+        )
+        return scheduler
+
+    def get_eval_dataset(self, **kwargs):
+        from yolox.data import COCODataset, ValTransform
+        testdev = kwargs.get("testdev", False)
+        legacy = kwargs.get("legacy", False)
+
+        return COCODataset(
+            data_dir=self.data_dir,
+            json_file=self.val_ann if not testdev else self.test_ann,
+            name="val2017" if not testdev else "test2017",
+            img_size=self.test_size,
+            preproc=ValTransform(legacy=legacy),
+        )
+
+    def get_eval_loader(self, batch_size, is_distributed, **kwargs):
+        valdataset = self.get_eval_dataset(**kwargs)
+
+        if is_distributed:
+            batch_size = batch_size // dist.get_world_size()
+            sampler = torch.utils.data.distributed.DistributedSampler(
+                valdataset, shuffle=False
+            )
+        else:
+            sampler = torch.utils.data.SequentialSampler(valdataset)
+
+        dataloader_kwargs = {
+            "num_workers": self.data_num_workers,
+            "pin_memory": True,
+            "sampler": sampler,
+        }
+        dataloader_kwargs["batch_size"] = batch_size
+        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
+
+        return val_loader
+
+    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
+        from yolox.evaluators import COCOEvaluator
+
+        return COCOEvaluator(
+            dataloader=self.get_eval_loader(batch_size, is_distributed,
+                                            testdev=testdev, legacy=legacy),
+            img_size=self.test_size,
+            confthre=self.test_conf,
+            nmsthre=self.nmsthre,
+            num_classes=self.num_classes,
+            testdev=testdev,
+        )
+
+    def get_trainer(self, args):
+        from yolox.core import Trainer
+        trainer = Trainer(self, args)
+        # NOTE: trainer shouldn't be an attribute of exp object
+        return trainer
+
+    def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
+        return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
+
+
+def check_exp_value(exp: Exp):
+    h, w = exp.input_size
+    assert h % 32 == 0 and w % 32 == 0, "input size must be multiples of 32"
diff --git a/multimodal/YOLOX/yolox/layers/__init__.py b/multimodal/YOLOX/yolox/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc9cf513818289977d5938e11efdc8d931032fae
--- /dev/null
+++ b/multimodal/YOLOX/yolox/layers/__init__.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+# import torch first to make jit op work without `ImportError of libc10.so`
+import torch  # noqa
+
+from .jit_ops import FastCOCOEvalOp, JitOp
+
+try:
+    from .fast_coco_eval_api import COCOeval_opt
+except ImportError:  #  exception will be raised when users build yolox from source
+    pass
diff --git a/multimodal/YOLOX/yolox/layers/cocoeval/cocoeval.cpp b/multimodal/YOLOX/yolox/layers/cocoeval/cocoeval.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2e63bc9952918060f55999ec100b283d83616b46
--- /dev/null
+++ b/multimodal/YOLOX/yolox/layers/cocoeval/cocoeval.cpp
@@ -0,0 +1,502 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#include "cocoeval.h"
+#include <time.h>
+#include <algorithm>
+#include <cstdint>
+#include <numeric>
+
+using namespace pybind11::literals;
+
+namespace COCOeval {
+
+// Sort detections from highest score to lowest, such that
+// detection_instances[detection_sorted_indices[t]] >=
+// detection_instances[detection_sorted_indices[t+1]].  Use stable_sort to match
+// original COCO API
+void SortInstancesByDetectionScore(
+    const std::vector<InstanceAnnotation>& detection_instances,
+    std::vector<uint64_t>* detection_sorted_indices) {
+  detection_sorted_indices->resize(detection_instances.size());
+  std::iota(
+      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
+  std::stable_sort(
+      detection_sorted_indices->begin(),
+      detection_sorted_indices->end(),
+      [&detection_instances](size_t j1, size_t j2) {
+        return detection_instances[j1].score > detection_instances[j2].score;
+      });
+}
+
+// Partition the ground truth objects based on whether or not to ignore them
+// based on area
+void SortInstancesByIgnore(
+    const std::array<double, 2>& area_range,
+    const std::vector<InstanceAnnotation>& ground_truth_instances,
+    std::vector<uint64_t>* ground_truth_sorted_indices,
+    std::vector<bool>* ignores) {
+  ignores->clear();
+  ignores->reserve(ground_truth_instances.size());
+  for (auto o : ground_truth_instances) {
+    ignores->push_back(
+        o.ignore || o.area < area_range[0] || o.area > area_range[1]);
+  }
+
+  ground_truth_sorted_indices->resize(ground_truth_instances.size());
+  std::iota(
+      ground_truth_sorted_indices->begin(),
+      ground_truth_sorted_indices->end(),
+      0);
+  std::stable_sort(
+      ground_truth_sorted_indices->begin(),
+      ground_truth_sorted_indices->end(),
+      [&ignores](size_t j1, size_t j2) {
+        return (int)(*ignores)[j1] < (int)(*ignores)[j2];
+      });
+}
+
+// For each IOU threshold, greedily match each detected instance to a ground
+// truth instance (if possible) and store the results
+void MatchDetectionsToGroundTruth(
+    const std::vector<InstanceAnnotation>& detection_instances,
+    const std::vector<uint64_t>& detection_sorted_indices,
+    const std::vector<InstanceAnnotation>& ground_truth_instances,
+    const std::vector<uint64_t>& ground_truth_sorted_indices,
+    const std::vector<bool>& ignores,
+    const std::vector<std::vector<double>>& ious,
+    const std::vector<double>& iou_thresholds,
+    const std::array<double, 2>& area_range,
+    ImageEvaluation* results) {
+  // Initialize memory to store return data matches and ignore
+  const int num_iou_thresholds = iou_thresholds.size();
+  const int num_ground_truth = ground_truth_sorted_indices.size();
+  const int num_detections = detection_sorted_indices.size();
+  std::vector<uint64_t> ground_truth_matches(
+      num_iou_thresholds * num_ground_truth, 0);
+  std::vector<uint64_t>& detection_matches = results->detection_matches;
+  std::vector<bool>& detection_ignores = results->detection_ignores;
+  std::vector<bool>& ground_truth_ignores = results->ground_truth_ignores;
+  detection_matches.resize(num_iou_thresholds * num_detections, 0);
+  detection_ignores.resize(num_iou_thresholds * num_detections, false);
+  ground_truth_ignores.resize(num_ground_truth);
+  for (auto g = 0; g < num_ground_truth; ++g) {
+    ground_truth_ignores[g] = ignores[ground_truth_sorted_indices[g]];
+  }
+
+  for (auto t = 0; t < num_iou_thresholds; ++t) {
+    for (auto d = 0; d < num_detections; ++d) {
+      // information about best match so far (match=-1 -> unmatched)
+      double best_iou = std::min(iou_thresholds[t], 1 - 1e-10);
+      int match = -1;
+      for (auto g = 0; g < num_ground_truth; ++g) {
+        // if this ground truth instance is already matched and not a
+        // crowd, it cannot be matched to another detection
+        if (ground_truth_matches[t * num_ground_truth + g] > 0 &&
+            !ground_truth_instances[ground_truth_sorted_indices[g]].is_crowd) {
+          continue;
+        }
+
+        // if detected instance matched to a regular ground truth
+        // instance, we can break on the first ground truth instance
+        // tagged as ignore (because they are sorted by the ignore tag)
+        if (match >= 0 && !ground_truth_ignores[match] &&
+            ground_truth_ignores[g]) {
+          break;
+        }
+
+        // if IOU overlap is the best so far, store the match appropriately
+        if (ious[d][ground_truth_sorted_indices[g]] >= best_iou) {
+          best_iou = ious[d][ground_truth_sorted_indices[g]];
+          match = g;
+        }
+      }
+      // if match was made, store id of match for both detection and
+      // ground truth
+      if (match >= 0) {
+        detection_ignores[t * num_detections + d] = ground_truth_ignores[match];
+        detection_matches[t * num_detections + d] =
+            ground_truth_instances[ground_truth_sorted_indices[match]].id;
+        ground_truth_matches[t * num_ground_truth + match] =
+            detection_instances[detection_sorted_indices[d]].id;
+      }
+
+      // set unmatched detections outside of area range to ignore
+      const InstanceAnnotation& detection =
+          detection_instances[detection_sorted_indices[d]];
+      detection_ignores[t * num_detections + d] =
+          detection_ignores[t * num_detections + d] ||
+          (detection_matches[t * num_detections + d] == 0 &&
+           (detection.area < area_range[0] || detection.area > area_range[1]));
+    }
+  }
+
+  // store detection score results
+  results->detection_scores.resize(detection_sorted_indices.size());
+  for (size_t d = 0; d < detection_sorted_indices.size(); ++d) {
+    results->detection_scores[d] =
+        detection_instances[detection_sorted_indices[d]].score;
+  }
+}
+
+std::vector<ImageEvaluation> EvaluateImages(
+    const std::vector<std::array<double, 2>>& area_ranges,
+    int max_detections,
+    const std::vector<double>& iou_thresholds,
+    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
+    const ImageCategoryInstances<InstanceAnnotation>&
+        image_category_ground_truth_instances,
+    const ImageCategoryInstances<InstanceAnnotation>&
+        image_category_detection_instances) {
+  const int num_area_ranges = area_ranges.size();
+  const int num_images = image_category_ground_truth_instances.size();
+  const int num_categories =
+      image_category_ious.size() > 0 ? image_category_ious[0].size() : 0;
+  std::vector<uint64_t> detection_sorted_indices;
+  std::vector<uint64_t> ground_truth_sorted_indices;
+  std::vector<bool> ignores;
+  std::vector<ImageEvaluation> results_all(
+      num_images * num_area_ranges * num_categories);
+
+  // Store results for each image, category, and area range combination. Results
+  // for each IOU threshold are packed into the same ImageEvaluation object
+  for (auto i = 0; i < num_images; ++i) {
+    for (auto c = 0; c < num_categories; ++c) {
+      const std::vector<InstanceAnnotation>& ground_truth_instances =
+          image_category_ground_truth_instances[i][c];
+      const std::vector<InstanceAnnotation>& detection_instances =
+          image_category_detection_instances[i][c];
+
+      SortInstancesByDetectionScore(
+          detection_instances, &detection_sorted_indices);
+      if ((int)detection_sorted_indices.size() > max_detections) {
+        detection_sorted_indices.resize(max_detections);
+      }
+
+      for (size_t a = 0; a < area_ranges.size(); ++a) {
+        SortInstancesByIgnore(
+            area_ranges[a],
+            ground_truth_instances,
+            &ground_truth_sorted_indices,
+            &ignores);
+
+        MatchDetectionsToGroundTruth(
+            detection_instances,
+            detection_sorted_indices,
+            ground_truth_instances,
+            ground_truth_sorted_indices,
+            ignores,
+            image_category_ious[i][c],
+            iou_thresholds,
+            area_ranges[a],
+            &results_all
+                [c * num_area_ranges * num_images + a * num_images + i]);
+      }
+    }
+  }
+
+  return results_all;
+}
+
+// Convert a python list to a vector
+template <typename T>
+std::vector<T> list_to_vec(const py::list& l) {
+  std::vector<T> v(py::len(l));
+  for (int i = 0; i < (int)py::len(l); ++i) {
+    v[i] = l[i].cast<T>();
+  }
+  return v;
+}
+
+// Helper function to Accumulate()
+// Considers the evaluation results applicable to a particular category, area
+// range, and max_detections parameter setting, which begin at
+// evaluations[evaluation_index].  Extracts a sorted list of length n of all
+// applicable detection instances concatenated across all images in the dataset,
+// which are represented by the outputs evaluation_indices, detection_scores,
+// image_detection_indices, and detection_sorted_indices--all of which are
+// length n. evaluation_indices[i] stores the applicable index into
+// evaluations[] for instance i, which has detection score detection_score[i],
+// and is the image_detection_indices[i]'th of the list of detections
+// for the image containing i.  detection_sorted_indices[] defines a sorted
+// permutation of the 3 other outputs
+int BuildSortedDetectionList(
+    const std::vector<ImageEvaluation>& evaluations,
+    const int64_t evaluation_index,
+    const int64_t num_images,
+    const int max_detections,
+    std::vector<uint64_t>* evaluation_indices,
+    std::vector<double>* detection_scores,
+    std::vector<uint64_t>* detection_sorted_indices,
+    std::vector<uint64_t>* image_detection_indices) {
+  assert(evaluations.size() >= evaluation_index + num_images);
+
+  // Extract a list of object instances of the applicable category, area
+  // range, and max detections requirements such that they can be sorted
+  image_detection_indices->clear();
+  evaluation_indices->clear();
+  detection_scores->clear();
+  image_detection_indices->reserve(num_images * max_detections);
+  evaluation_indices->reserve(num_images * max_detections);
+  detection_scores->reserve(num_images * max_detections);
+  int num_valid_ground_truth = 0;
+  for (auto i = 0; i < num_images; ++i) {
+    const ImageEvaluation& evaluation = evaluations[evaluation_index + i];
+
+    for (int d = 0;
+         d < (int)evaluation.detection_scores.size() && d < max_detections;
+         ++d) { // detected instances
+      evaluation_indices->push_back(evaluation_index + i);
+      image_detection_indices->push_back(d);
+      detection_scores->push_back(evaluation.detection_scores[d]);
+    }
+    for (auto ground_truth_ignore : evaluation.ground_truth_ignores) {
+      if (!ground_truth_ignore) {
+        ++num_valid_ground_truth;
+      }
+    }
+  }
+
+  // Sort detections by decreasing score, using stable sort to match
+  // python implementation
+  detection_sorted_indices->resize(detection_scores->size());
+  std::iota(
+      detection_sorted_indices->begin(), detection_sorted_indices->end(), 0);
+  std::stable_sort(
+      detection_sorted_indices->begin(),
+      detection_sorted_indices->end(),
+      [&detection_scores](size_t j1, size_t j2) {
+        return (*detection_scores)[j1] > (*detection_scores)[j2];
+      });
+
+  return num_valid_ground_truth;
+}
+
+// Helper function to Accumulate()
+// Compute a precision recall curve given a sorted list of detected instances
+// encoded in evaluations, evaluation_indices, detection_scores,
+// detection_sorted_indices, image_detection_indices (see
+// BuildSortedDetectionList()). Using vectors precisions and recalls
+// and temporary storage, output the results into precisions_out, recalls_out,
+// and scores_out, which are large buffers containing many precion/recall curves
+// for all possible parameter settings, with precisions_out_index and
+// recalls_out_index defining the applicable indices to store results.
+void ComputePrecisionRecallCurve(
+    const int64_t precisions_out_index,
+    const int64_t precisions_out_stride,
+    const int64_t recalls_out_index,
+    const std::vector<double>& recall_thresholds,
+    const int iou_threshold_index,
+    const int num_iou_thresholds,
+    const int num_valid_ground_truth,
+    const std::vector<ImageEvaluation>& evaluations,
+    const std::vector<uint64_t>& evaluation_indices,
+    const std::vector<double>& detection_scores,
+    const std::vector<uint64_t>& detection_sorted_indices,
+    const std::vector<uint64_t>& image_detection_indices,
+    std::vector<double>* precisions,
+    std::vector<double>* recalls,
+    std::vector<double>* precisions_out,
+    std::vector<double>* scores_out,
+    std::vector<double>* recalls_out) {
+  assert(recalls_out->size() > recalls_out_index);
+
+  // Compute precision/recall for each instance in the sorted list of detections
+  int64_t true_positives_sum = 0, false_positives_sum = 0;
+  precisions->clear();
+  recalls->clear();
+  precisions->reserve(detection_sorted_indices.size());
+  recalls->reserve(detection_sorted_indices.size());
+  assert(!evaluations.empty() || detection_sorted_indices.empty());
+  for (auto detection_sorted_index : detection_sorted_indices) {
+    const ImageEvaluation& evaluation =
+        evaluations[evaluation_indices[detection_sorted_index]];
+    const auto num_detections =
+        evaluation.detection_matches.size() / num_iou_thresholds;
+    const auto detection_index = iou_threshold_index * num_detections +
+        image_detection_indices[detection_sorted_index];
+    assert(evaluation.detection_matches.size() > detection_index);
+    assert(evaluation.detection_ignores.size() > detection_index);
+    const int64_t detection_match =
+        evaluation.detection_matches[detection_index];
+    const bool detection_ignores =
+        evaluation.detection_ignores[detection_index];
+    const auto true_positive = detection_match > 0 && !detection_ignores;
+    const auto false_positive = detection_match == 0 && !detection_ignores;
+    if (true_positive) {
+      ++true_positives_sum;
+    }
+    if (false_positive) {
+      ++false_positives_sum;
+    }
+
+    const double recall =
+        static_cast<double>(true_positives_sum) / num_valid_ground_truth;
+    recalls->push_back(recall);
+    const int64_t num_valid_detections =
+        true_positives_sum + false_positives_sum;
+    const double precision = num_valid_detections > 0
+        ? static_cast<double>(true_positives_sum) / num_valid_detections
+        : 0.0;
+    precisions->push_back(precision);
+  }
+
+  (*recalls_out)[recalls_out_index] = !recalls->empty() ? recalls->back() : 0;
+
+  for (int64_t i = static_cast<int64_t>(precisions->size()) - 1; i > 0; --i) {
+    if ((*precisions)[i] > (*precisions)[i - 1]) {
+      (*precisions)[i - 1] = (*precisions)[i];
+    }
+  }
+
+  // Sample the per instance precision/recall list at each recall threshold
+  for (size_t r = 0; r < recall_thresholds.size(); ++r) {
+    // first index in recalls >= recall_thresholds[r]
+    std::vector<double>::iterator low = std::lower_bound(
+        recalls->begin(), recalls->end(), recall_thresholds[r]);
+    size_t precisions_index = low - recalls->begin();
+
+    const auto results_ind = precisions_out_index + r * precisions_out_stride;
+    assert(results_ind < precisions_out->size());
+    assert(results_ind < scores_out->size());
+    if (precisions_index < precisions->size()) {
+      (*precisions_out)[results_ind] = (*precisions)[precisions_index];
+      (*scores_out)[results_ind] =
+          detection_scores[detection_sorted_indices[precisions_index]];
+    } else {
+      (*precisions_out)[results_ind] = 0;
+      (*scores_out)[results_ind] = 0;
+    }
+  }
+}
+py::dict Accumulate(
+    const py::object& params,
+    const std::vector<ImageEvaluation>& evaluations) {
+  const std::vector<double> recall_thresholds =
+      list_to_vec<double>(params.attr("recThrs"));
+  const std::vector<int> max_detections =
+      list_to_vec<int>(params.attr("maxDets"));
+  const int num_iou_thresholds = py::len(params.attr("iouThrs"));
+  const int num_recall_thresholds = py::len(params.attr("recThrs"));
+  const int num_categories = params.attr("useCats").cast<int>() == 1
+      ? py::len(params.attr("catIds"))
+      : 1;
+  const int num_area_ranges = py::len(params.attr("areaRng"));
+  const int num_max_detections = py::len(params.attr("maxDets"));
+  const int num_images = py::len(params.attr("imgIds"));
+
+  std::vector<double> precisions_out(
+      num_iou_thresholds * num_recall_thresholds * num_categories *
+          num_area_ranges * num_max_detections,
+      -1);
+  std::vector<double> recalls_out(
+      num_iou_thresholds * num_categories * num_area_ranges *
+          num_max_detections,
+      -1);
+  std::vector<double> scores_out(
+      num_iou_thresholds * num_recall_thresholds * num_categories *
+          num_area_ranges * num_max_detections,
+      -1);
+
+  // Consider the list of all detected instances in the entire dataset in one
+  // large list.  evaluation_indices, detection_scores,
+  // image_detection_indices, and detection_sorted_indices all have the same
+  // length as this list, such that each entry corresponds to one detected
+  // instance
+  std::vector<uint64_t> evaluation_indices; // indices into evaluations[]
+  std::vector<double> detection_scores; // detection scores of each instance
+  std::vector<uint64_t> detection_sorted_indices; // sorted indices of all
+                                                  // instances in the dataset
+  std::vector<uint64_t>
+      image_detection_indices; // indices into the list of detected instances in
+                               // the same image as each instance
+  std::vector<double> precisions, recalls;
+
+  for (auto c = 0; c < num_categories; ++c) {
+    for (auto a = 0; a < num_area_ranges; ++a) {
+      for (auto m = 0; m < num_max_detections; ++m) {
+        // The COCO PythonAPI assumes evaluations[] (the return value of
+        // COCOeval::EvaluateImages() is one long list storing results for each
+        // combination of category, area range, and image id, with categories in
+        // the outermost loop and images in the innermost loop.
+        const int64_t evaluations_index =
+            c * num_area_ranges * num_images + a * num_images;
+        int num_valid_ground_truth = BuildSortedDetectionList(
+            evaluations,
+            evaluations_index,
+            num_images,
+            max_detections[m],
+            &evaluation_indices,
+            &detection_scores,
+            &detection_sorted_indices,
+            &image_detection_indices);
+
+        if (num_valid_ground_truth == 0) {
+          continue;
+        }
+
+        for (auto t = 0; t < num_iou_thresholds; ++t) {
+          // recalls_out is a flattened vectors representing a
+          // num_iou_thresholds X num_categories X num_area_ranges X
+          // num_max_detections matrix
+          const int64_t recalls_out_index =
+              t * num_categories * num_area_ranges * num_max_detections +
+              c * num_area_ranges * num_max_detections +
+              a * num_max_detections + m;
+
+          // precisions_out and scores_out are flattened vectors
+          // representing a num_iou_thresholds X num_recall_thresholds X
+          // num_categories X num_area_ranges X num_max_detections matrix
+          const int64_t precisions_out_stride =
+              num_categories * num_area_ranges * num_max_detections;
+          const int64_t precisions_out_index = t * num_recall_thresholds *
+                  num_categories * num_area_ranges * num_max_detections +
+              c * num_area_ranges * num_max_detections +
+              a * num_max_detections + m;
+
+          ComputePrecisionRecallCurve(
+              precisions_out_index,
+              precisions_out_stride,
+              recalls_out_index,
+              recall_thresholds,
+              t,
+              num_iou_thresholds,
+              num_valid_ground_truth,
+              evaluations,
+              evaluation_indices,
+              detection_scores,
+              detection_sorted_indices,
+              image_detection_indices,
+              &precisions,
+              &recalls,
+              &precisions_out,
+              &scores_out,
+              &recalls_out);
+        }
+      }
+    }
+  }
+
+  time_t rawtime;
+  struct tm local_time;
+  std::array<char, 200> buffer;
+  time(&rawtime);
+#ifdef _WIN32
+  localtime_s(&local_time, &rawtime);
+#else
+  localtime_r(&rawtime, &local_time);
+#endif
+  strftime(
+      buffer.data(), 200, "%Y-%m-%d %H:%num_max_detections:%S", &local_time);
+  return py::dict(
+      "params"_a = params,
+      "counts"_a = std::vector<int64_t>({num_iou_thresholds,
+                                         num_recall_thresholds,
+                                         num_categories,
+                                         num_area_ranges,
+                                         num_max_detections}),
+      "date"_a = buffer,
+      "precision"_a = precisions_out,
+      "recall"_a = recalls_out,
+      "scores"_a = scores_out);
+}
+
+} // namespace COCOeval
diff --git a/multimodal/YOLOX/yolox/layers/cocoeval/cocoeval.h b/multimodal/YOLOX/yolox/layers/cocoeval/cocoeval.h
new file mode 100644
index 0000000000000000000000000000000000000000..dbf5aab4b8303b8e199f10e1ecf2f634ca29cb42
--- /dev/null
+++ b/multimodal/YOLOX/yolox/layers/cocoeval/cocoeval.h
@@ -0,0 +1,98 @@
+// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+#pragma once
+
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <pybind11/stl_bind.h>
+#include <vector>
+
+namespace py = pybind11;
+
+namespace COCOeval {
+
+// Annotation data for a single object instance in an image
+struct InstanceAnnotation {
+  InstanceAnnotation(
+      uint64_t id,
+      double score,
+      double area,
+      bool is_crowd,
+      bool ignore)
+      : id{id}, score{score}, area{area}, is_crowd{is_crowd}, ignore{ignore} {}
+  uint64_t id;
+  double score = 0.;
+  double area = 0.;
+  bool is_crowd = false;
+  bool ignore = false;
+};
+
+// Stores intermediate results for evaluating detection results for a single
+// image that has D detected instances and G ground truth instances. This stores
+// matches between detected and ground truth instances
+struct ImageEvaluation {
+  // For each of the D detected instances, the id of the matched ground truth
+  // instance, or 0 if unmatched
+  std::vector<uint64_t> detection_matches;
+
+  // The detection score of each of the D detected instances
+  std::vector<double> detection_scores;
+
+  // Marks whether or not each of G instances was ignored from evaluation (e.g.,
+  // because it's outside area_range)
+  std::vector<bool> ground_truth_ignores;
+
+  // Marks whether or not each of D instances was ignored from evaluation (e.g.,
+  // because it's outside aRng)
+  std::vector<bool> detection_ignores;
+};
+
+template <class T>
+using ImageCategoryInstances = std::vector<std::vector<std::vector<T>>>;
+
+// C++ implementation of COCO API cocoeval.py::COCOeval.evaluateImg().  For each
+// combination of image, category, area range settings, and IOU thresholds to
+// evaluate, it matches detected instances to ground truth instances and stores
+// the results into a vector of ImageEvaluation results, which will be
+// interpreted by the COCOeval::Accumulate() function to produce precion-recall
+// curves.  The parameters of nested vectors have the following semantics:
+//   image_category_ious[i][c][d][g] is the intersection over union of the d'th
+//     detected instance and g'th ground truth instance of
+//     category category_ids[c] in image image_ids[i]
+//   image_category_ground_truth_instances[i][c] is a vector of ground truth
+//     instances in image image_ids[i] of category category_ids[c]
+//   image_category_detection_instances[i][c] is a vector of detected
+//     instances in image image_ids[i] of category category_ids[c]
+std::vector<ImageEvaluation> EvaluateImages(
+    const std::vector<std::array<double, 2>>& area_ranges, // vector of 2-tuples
+    int max_detections,
+    const std::vector<double>& iou_thresholds,
+    const ImageCategoryInstances<std::vector<double>>& image_category_ious,
+    const ImageCategoryInstances<InstanceAnnotation>&
+        image_category_ground_truth_instances,
+    const ImageCategoryInstances<InstanceAnnotation>&
+        image_category_detection_instances);
+
+// C++ implementation of COCOeval.accumulate(), which generates precision
+// recall curves for each set of category, IOU threshold, detection area range,
+// and max number of detections parameters.  It is assumed that the parameter
+// evaluations is the return value of the functon COCOeval::EvaluateImages(),
+// which was called with the same parameter settings params
+py::dict Accumulate(
+    const py::object& params,
+    const std::vector<ImageEvaluation>& evalutations);
+
+} // namespace COCOeval
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("COCOevalAccumulate", &COCOeval::Accumulate, "COCOeval::Accumulate");
+    m.def(
+        "COCOevalEvaluateImages",
+        &COCOeval::EvaluateImages,
+        "COCOeval::EvaluateImages");
+    pybind11::class_<COCOeval::InstanceAnnotation>(m, "InstanceAnnotation")
+        .def(pybind11::init<uint64_t, double, double, bool, bool>());
+    pybind11::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation")
+        .def(pybind11::init<>());
+}
diff --git a/multimodal/YOLOX/yolox/layers/fast_coco_eval_api.py b/multimodal/YOLOX/yolox/layers/fast_coco_eval_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f3aeb5517077718331074c3795ed2d10b4954bc
--- /dev/null
+++ b/multimodal/YOLOX/yolox/layers/fast_coco_eval_api.py
@@ -0,0 +1,151 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# This file comes from
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/evaluation/fast_eval_api.py
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import copy
+import time
+
+import numpy as np
+from pycocotools.cocoeval import COCOeval
+
+from .jit_ops import FastCOCOEvalOp
+
+
+class COCOeval_opt(COCOeval):
+    """
+    This is a slightly modified version of the original COCO API, where the functions evaluateImg()
+    and accumulate() are implemented in C++ to speedup evaluation
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.module = FastCOCOEvalOp().load()
+
+    def evaluate(self):
+        """
+        Run per image evaluation on given images and store results in self.evalImgs_cpp, a
+        datastructure that isn't readable from Python but is used by a c++ implementation of
+        accumulate().  Unlike the original COCO PythonAPI, we don't populate the datastructure
+        self.evalImgs because this datastructure is a computational bottleneck.
+        :return: None
+        """
+        tic = time.time()
+
+        print("Running per image evaluation...")
+        p = self.params
+        # add backward compatibility if useSegm is specified in params
+        if p.useSegm is not None:
+            p.iouType = "segm" if p.useSegm == 1 else "bbox"
+            print(
+                "useSegm (deprecated) is not None. Running {} evaluation".format(
+                    p.iouType
+                )
+            )
+        print("Evaluate annotation type *{}*".format(p.iouType))
+        p.imgIds = list(np.unique(p.imgIds))
+        if p.useCats:
+            p.catIds = list(np.unique(p.catIds))
+        p.maxDets = sorted(p.maxDets)
+        self.params = p
+
+        self._prepare()
+
+        # loop through images, area range, max detection number
+        catIds = p.catIds if p.useCats else [-1]
+
+        if p.iouType == "segm" or p.iouType == "bbox":
+            computeIoU = self.computeIoU
+        elif p.iouType == "keypoints":
+            computeIoU = self.computeOks
+        self.ious = {
+            (imgId, catId): computeIoU(imgId, catId)
+            for imgId in p.imgIds
+            for catId in catIds
+        }
+
+        maxDet = p.maxDets[-1]
+
+        # <<<< Beginning of code differences with original COCO API
+        def convert_instances_to_cpp(instances, is_det=False):
+            # Convert annotations for a list of instances in an image to a format that's fast
+            # to access in C++
+            instances_cpp = []
+            for instance in instances:
+                instance_cpp = self.module.InstanceAnnotation(
+                    int(instance["id"]),
+                    instance["score"] if is_det else instance.get("score", 0.0),
+                    instance["area"],
+                    bool(instance.get("iscrowd", 0)),
+                    bool(instance.get("ignore", 0)),
+                )
+                instances_cpp.append(instance_cpp)
+            return instances_cpp
+
+        # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++
+        ground_truth_instances = [
+            [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds]
+            for imgId in p.imgIds
+        ]
+        detected_instances = [
+            [
+                convert_instances_to_cpp(self._dts[imgId, catId], is_det=True)
+                for catId in p.catIds
+            ]
+            for imgId in p.imgIds
+        ]
+        ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds]
+
+        if not p.useCats:
+            # For each image, flatten per-category lists into a single list
+            ground_truth_instances = [
+                [[o for c in i for o in c]] for i in ground_truth_instances
+            ]
+            detected_instances = [
+                [[o for c in i for o in c]] for i in detected_instances
+            ]
+
+        # Call C++ implementation of self.evaluateImgs()
+        self._evalImgs_cpp = self.module.COCOevalEvaluateImages(
+            p.areaRng,
+            maxDet,
+            p.iouThrs,
+            ious,
+            ground_truth_instances,
+            detected_instances,
+        )
+        self._evalImgs = None
+
+        self._paramsEval = copy.deepcopy(self.params)
+        toc = time.time()
+        print("COCOeval_opt.evaluate() finished in {:0.2f} seconds.".format(toc - tic))
+        # >>>> End of code differences with original COCO API
+
+    def accumulate(self):
+        """
+        Accumulate per image evaluation results and store the result in self.eval.  Does not
+        support changing parameter settings from those used by self.evaluate()
+        """
+        print("Accumulating evaluation results...")
+        tic = time.time()
+        if not hasattr(self, "_evalImgs_cpp"):
+            print("Please run evaluate() first")
+
+        self.eval = self.module.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp)
+
+        # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections
+        self.eval["recall"] = np.array(self.eval["recall"]).reshape(
+            self.eval["counts"][:1] + self.eval["counts"][2:]
+        )
+
+        # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X
+        # num_area_ranges X num_max_detections
+        self.eval["precision"] = np.array(self.eval["precision"]).reshape(
+            self.eval["counts"]
+        )
+        self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"])
+        toc = time.time()
+        print(
+            "COCOeval_opt.accumulate() finished in {:0.2f} seconds.".format(toc - tic)
+        )
diff --git a/multimodal/YOLOX/yolox/layers/jit_ops.py b/multimodal/YOLOX/yolox/layers/jit_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fdac4de2b2cedbf523a887ce7564cbc6c372a28
--- /dev/null
+++ b/multimodal/YOLOX/yolox/layers/jit_ops.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii, Inc. and its affiliates. All Rights Reserved
+
+import glob
+import importlib
+import os
+import sys
+import time
+from typing import List
+
+__all__ = ["JitOp", "FastCOCOEvalOp"]
+
+
+class JitOp:
+    """
+    Just-in-time compilation of ops.
+
+    Some code of `JitOp` is inspired by `deepspeed.op_builder`,
+    check the following link for more details:
+    https://github.com/microsoft/DeepSpeed/blob/master/op_builder/builder.py
+    """
+
+    def __init__(self, name):
+        self.name = name
+
+    def absolute_name(self) -> str:
+        """Get absolute build path for cases where the op is pre-installed."""
+        pass
+
+    def sources(self) -> List:
+        """Get path list of source files of op.
+
+        NOTE: the path should be elative to root of package during building,
+            Otherwise, exception will be raised when building package.
+            However, for runtime building, path will be absolute.
+        """
+        pass
+
+    def include_dirs(self) -> List:
+        """
+        Get list of include paths, relative to root of package.
+
+        NOTE: the path should be elative to root of package.
+            Otherwise, exception will be raised when building package.
+        """
+        return []
+
+    def define_macros(self) -> List:
+        """Get list of macros to define for op"""
+        return []
+
+    def cxx_args(self) -> List:
+        """Get optional list of compiler flags to forward"""
+        args = ["-O2"] if sys.platform == "win32" else ["-O3", "-std=c++14", "-g", "-Wno-reorder"]
+        return args
+
+    def nvcc_args(self) -> List:
+        """Get optional list of compiler flags to forward to nvcc when building CUDA sources"""
+        args = [
+            "-O3", "--use_fast_math",
+            "-std=c++17" if sys.platform == "win32" else "-std=c++14",
+            "-U__CUDA_NO_HALF_OPERATORS__",
+            "-U__CUDA_NO_HALF_CONVERSIONS__",
+            "-U__CUDA_NO_HALF2_OPERATORS__",
+        ]
+        return args
+
+    def build_op(self):
+        from torch.utils.cpp_extension import CppExtension
+        return CppExtension(
+            name=self.absolute_name(),
+            sources=self.sources(),
+            include_dirs=self.include_dirs(),
+            define_macros=self.define_macros(),
+            extra_compile_args={
+                "cxx": self.cxx_args(),
+            },
+        )
+
+    def load(self, verbose=True):
+        try:
+            # try to import op from pre-installed package
+            return importlib.import_module(self.absolute_name())
+        except Exception:  # op not compiled, jit load
+            from yolox.utils import wait_for_the_master
+            with wait_for_the_master():  # to avoid race condition
+                return self.jit_load(verbose)
+
+    def jit_load(self, verbose=True):
+        from torch.utils.cpp_extension import load
+        from loguru import logger
+        try:
+            import ninja  # noqa
+        except ImportError:
+            if verbose:
+                logger.warning(
+                    f"Ninja is not installed, fall back to normal installation for {self.name}."
+                )
+
+        build_tik = time.time()
+        # build op and load
+        op_module = load(
+            name=self.name,
+            sources=self.sources(),
+            extra_cflags=self.cxx_args(),
+            extra_cuda_cflags=self.nvcc_args(),
+            verbose=verbose,
+        )
+        build_duration = time.time() - build_tik
+        if verbose:
+            logger.info(f"Load {self.name} op in {build_duration:.3f}s.")
+        return op_module
+
+    def clear_dynamic_library(self):
+        """Remove dynamic libraray files generated by JIT compilation."""
+        module = self.load()
+        os.remove(module.__file__)
+
+
+class FastCOCOEvalOp(JitOp):
+
+    def __init__(self, name="fast_cocoeval"):
+        super().__init__(name=name)
+
+    def absolute_name(self):
+        return f'yolox.layers.{self.name}'
+
+    def sources(self):
+        sources = glob.glob(os.path.join("yolox", "layers", "cocoeval", "*.cpp"))
+        if not sources:  # source will be empty list if the so file is removed after install
+            # use abosolute path to compile
+            import yolox
+            code_path = os.path.join(yolox.__path__[0], "layers", "cocoeval", "*.cpp")
+            sources = glob.glob(code_path)
+        return sources
+
+    def include_dirs(self):
+        return [os.path.join("yolox", "layers", "cocoeval")]
diff --git a/multimodal/YOLOX/yolox/models/__init__.py b/multimodal/YOLOX/yolox/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c74fd3064ac588a7c223018aa31fd2d46f95d062
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/__init__.py
@@ -0,0 +1,11 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+from .build import *
+from .darknet import CSPDarknet, Darknet
+from .losses import IOUloss
+from .yolo_fpn import YOLOFPN
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+from .yolox import YOLOX
diff --git a/multimodal/YOLOX/yolox/models/build.py b/multimodal/YOLOX/yolox/models/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..8edc87de9d1dd46b7e693ad15bdbd9ac753bd225
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/build.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+from torch import nn
+from torch.hub import load_state_dict_from_url
+
+__all__ = [
+    "create_yolox_model",
+    "yolox_nano",
+    "yolox_tiny",
+    "yolox_s",
+    "yolox_m",
+    "yolox_l",
+    "yolox_x",
+    "yolov3",
+    "yolox_custom"
+]
+
+_CKPT_ROOT_URL = "https://github.com/Megvii-BaseDetection/YOLOX/releases/download"
+_CKPT_FULL_PATH = {
+    "yolox-nano": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_nano.pth",
+    "yolox-tiny": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_tiny.pth",
+    "yolox-s": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_s.pth",
+    "yolox-m": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_m.pth",
+    "yolox-l": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_l.pth",
+    "yolox-x": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_x.pth",
+    "yolov3": f"{_CKPT_ROOT_URL}/0.1.1rc0/yolox_darknet.pth",
+}
+
+
+def create_yolox_model(name: str, pretrained: bool = True, num_classes: int = 80, device=None,
+                       exp_path: str = None, ckpt_path: str = None) -> nn.Module:
+    """creates and loads a YOLOX model
+
+    Args:
+        name (str): name of model. for example, "yolox-s", "yolox-tiny" or "yolox_custom"
+        if you want to load your own model.
+        pretrained (bool): load pretrained weights into the model. Default to True.
+        device (str): default device to for model. Default to None.
+        num_classes (int): number of model classes. Default to 80.
+        exp_path (str): path to your own experiment file. Required if name="yolox_custom"
+        ckpt_path (str): path to your own ckpt. Required if name="yolox_custom" and you want to
+            load a pretrained model
+
+
+    Returns:
+        YOLOX model (nn.Module)
+    """
+    from yolox.exp import get_exp, Exp
+
+    if device is None:
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device)
+
+    assert name in _CKPT_FULL_PATH or name == "yolox_custom", \
+        f"user should use one of value in {_CKPT_FULL_PATH.keys()} or \"yolox_custom\""
+    if name in _CKPT_FULL_PATH:
+        exp: Exp = get_exp(exp_name=name)
+        exp.num_classes = num_classes
+        yolox_model = exp.get_model()
+        if pretrained and num_classes == 80:
+            weights_url = _CKPT_FULL_PATH[name]
+            ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
+            if "model" in ckpt:
+                ckpt = ckpt["model"]
+            yolox_model.load_state_dict(ckpt)
+    else:
+        assert exp_path is not None, "for a \"yolox_custom\" model exp_path must be provided"
+        exp: Exp = get_exp(exp_file=exp_path)
+        yolox_model = exp.get_model()
+        if ckpt_path:
+            ckpt = torch.load(ckpt_path, map_location="cpu")
+            if "model" in ckpt:
+                ckpt = ckpt["model"]
+            yolox_model.load_state_dict(ckpt)
+
+    yolox_model.to(device)
+    return yolox_model
+
+
+def yolox_nano(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
+    return create_yolox_model("yolox-nano", pretrained, num_classes, device)
+
+
+def yolox_tiny(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
+    return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
+
+
+def yolox_s(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
+    return create_yolox_model("yolox-s", pretrained, num_classes, device)
+
+
+def yolox_m(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
+    return create_yolox_model("yolox-m", pretrained, num_classes, device)
+
+
+def yolox_l(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
+    return create_yolox_model("yolox-l", pretrained, num_classes, device)
+
+
+def yolox_x(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
+    return create_yolox_model("yolox-x", pretrained, num_classes, device)
+
+
+def yolov3(pretrained: bool = True, num_classes: int = 80, device: str = None) -> nn.Module:
+    return create_yolox_model("yolov3", pretrained, num_classes, device)
+
+
+def yolox_custom(ckpt_path: str = None, exp_path: str = None, device: str = None) -> nn.Module:
+    return create_yolox_model("yolox_custom", ckpt_path=ckpt_path, exp_path=exp_path, device=device)
diff --git a/multimodal/YOLOX/yolox/models/darknet.py b/multimodal/YOLOX/yolox/models/darknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3e053f163ade7b69979bcec86532466ab67eedf
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/darknet.py
@@ -0,0 +1,179 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+from torch import nn
+
+from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck
+
+
+class Darknet(nn.Module):
+    # number of blocks from dark2 to dark5.
+    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}
+
+    def __init__(
+        self,
+        depth,
+        in_channels=3,
+        stem_out_channels=32,
+        out_features=("dark3", "dark4", "dark5"),
+    ):
+        """
+        Args:
+            depth (int): depth of darknet used in model, usually use [21, 53] for this param.
+            in_channels (int): number of input channels, for example, use 3 for RGB image.
+            stem_out_channels (int): number of output channels of darknet stem.
+                It decides channels of darknet layer2 to layer5.
+            out_features (Tuple[str]): desired output layer name.
+        """
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        self.stem = nn.Sequential(
+            BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
+            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
+        )
+        in_channels = stem_out_channels * 2  # 64
+
+        num_blocks = Darknet.depth2blocks[depth]
+        # create darknet with `stem_out_channels` and `num_blocks` layers.
+        # to make model structure more clear, we don't use `for` statement in python.
+        self.dark2 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[0], stride=2)
+        )
+        in_channels *= 2  # 128
+        self.dark3 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[1], stride=2)
+        )
+        in_channels *= 2  # 256
+        self.dark4 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[2], stride=2)
+        )
+        in_channels *= 2  # 512
+
+        self.dark5 = nn.Sequential(
+            *self.make_group_layer(in_channels, num_blocks[3], stride=2),
+            *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2),
+        )
+
+    def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
+        "starts with conv layer then has `num_blocks` `ResLayer`"
+        return [
+            BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"),
+            *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)],
+        ]
+
+    def make_spp_block(self, filters_list, in_filters):
+        m = nn.Sequential(
+            *[
+                BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                SPPBottleneck(
+                    in_channels=filters_list[1],
+                    out_channels=filters_list[0],
+                    activation="lrelu",
+                ),
+                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
+                BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"),
+            ]
+        )
+        return m
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
+
+
+class CSPDarknet(nn.Module):
+    def __init__(
+        self,
+        dep_mul,
+        wid_mul,
+        out_features=("dark3", "dark4", "dark5"),
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        assert out_features, "please provide output features of Darknet"
+        self.out_features = out_features
+        Conv = DWConv if depthwise else BaseConv
+
+        base_channels = int(wid_mul * 64)  # 64
+        base_depth = max(round(dep_mul * 3), 1)  # 3
+
+        # stem
+        self.stem = Focus(3, base_channels, ksize=3, act=act)
+
+        # dark2
+        self.dark2 = nn.Sequential(
+            Conv(base_channels, base_channels * 2, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 2,
+                base_channels * 2,
+                n=base_depth,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark3
+        self.dark3 = nn.Sequential(
+            Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 4,
+                base_channels * 4,
+                n=base_depth * 3,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark4
+        self.dark4 = nn.Sequential(
+            Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
+            CSPLayer(
+                base_channels * 8,
+                base_channels * 8,
+                n=base_depth * 3,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+        # dark5
+        self.dark5 = nn.Sequential(
+            Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
+            SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
+            CSPLayer(
+                base_channels * 16,
+                base_channels * 16,
+                n=base_depth,
+                shortcut=False,
+                depthwise=depthwise,
+                act=act,
+            ),
+        )
+
+    def forward(self, x):
+        outputs = {}
+        x = self.stem(x)
+        outputs["stem"] = x
+        x = self.dark2(x)
+        outputs["dark2"] = x
+        x = self.dark3(x)
+        outputs["dark3"] = x
+        x = self.dark4(x)
+        outputs["dark4"] = x
+        x = self.dark5(x)
+        outputs["dark5"] = x
+        return {k: v for k, v in outputs.items() if k in self.out_features}
diff --git a/multimodal/YOLOX/yolox/models/losses.py b/multimodal/YOLOX/yolox/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..77b4d8ef7660880031f4ef23c82ba3a85b6fd254
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/losses.py
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+
+class IOUloss(nn.Module):
+    def __init__(self, reduction="none", loss_type="iou"):
+        super(IOUloss, self).__init__()
+        self.reduction = reduction
+        self.loss_type = loss_type
+
+    def forward(self, pred, target):
+        assert pred.shape[0] == target.shape[0]
+
+        pred = pred.view(-1, 4)
+        target = target.view(-1, 4)
+        tl = torch.max(
+            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
+        )
+        br = torch.min(
+            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
+        )
+
+        area_p = torch.prod(pred[:, 2:], 1)
+        area_g = torch.prod(target[:, 2:], 1)
+
+        en = (tl < br).type(tl.type()).prod(dim=1)
+        area_i = torch.prod(br - tl, 1) * en
+        area_u = area_p + area_g - area_i
+        iou = (area_i) / (area_u + 1e-16)
+
+        if self.loss_type == "iou":
+            loss = 1 - iou ** 2
+        elif self.loss_type == "giou":
+            c_tl = torch.min(
+                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
+            )
+            c_br = torch.max(
+                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
+            )
+            area_c = torch.prod(c_br - c_tl, 1)
+            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
+            loss = 1 - giou.clamp(min=-1.0, max=1.0)
+
+        if self.reduction == "mean":
+            loss = loss.mean()
+        elif self.reduction == "sum":
+            loss = loss.sum()
+
+        return loss
diff --git a/multimodal/YOLOX/yolox/models/network_blocks.py b/multimodal/YOLOX/yolox/models/network_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..68aacfc33208eab072422e0647742006984dfdfd
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/network_blocks.py
@@ -0,0 +1,210 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_activation(name="silu", inplace=True):
+    if name == "silu":
+        module = nn.SiLU(inplace=inplace)
+    elif name == "relu":
+        module = nn.ReLU(inplace=inplace)
+    elif name == "lrelu":
+        module = nn.LeakyReLU(0.1, inplace=inplace)
+    else:
+        raise AttributeError("Unsupported act type: {}".format(name))
+    return module
+
+
+class BaseConv(nn.Module):
+    """A Conv2d -> Batchnorm -> silu/leaky relu block"""
+
+    def __init__(
+        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
+    ):
+        super().__init__()
+        # same padding
+        pad = (ksize - 1) // 2
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=ksize,
+            stride=stride,
+            padding=pad,
+            groups=groups,
+            bias=bias,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.act = get_activation(act, inplace=True)
+
+    def forward(self, x):
+        return self.act(self.bn(self.conv(x)))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+class DWConv(nn.Module):
+    """Depthwise Conv + Conv"""
+
+    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
+        super().__init__()
+        self.dconv = BaseConv(
+            in_channels,
+            in_channels,
+            ksize=ksize,
+            stride=stride,
+            groups=in_channels,
+            act=act,
+        )
+        self.pconv = BaseConv(
+            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
+        )
+
+    def forward(self, x):
+        x = self.dconv(x)
+        return self.pconv(x)
+
+
+class Bottleneck(nn.Module):
+    # Standard bottleneck
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        shortcut=True,
+        expansion=0.5,
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)
+        Conv = DWConv if depthwise else BaseConv
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
+        self.use_add = shortcut and in_channels == out_channels
+
+    def forward(self, x):
+        y = self.conv2(self.conv1(x))
+        if self.use_add:
+            y = y + x
+        return y
+
+
+class ResLayer(nn.Module):
+    "Residual layer with `in_channels` inputs."
+
+    def __init__(self, in_channels: int):
+        super().__init__()
+        mid_channels = in_channels // 2
+        self.layer1 = BaseConv(
+            in_channels, mid_channels, ksize=1, stride=1, act="lrelu"
+        )
+        self.layer2 = BaseConv(
+            mid_channels, in_channels, ksize=3, stride=1, act="lrelu"
+        )
+
+    def forward(self, x):
+        out = self.layer2(self.layer1(x))
+        return x + out
+
+
+class SPPBottleneck(nn.Module):
+    """Spatial pyramid pooling layer used in YOLOv3-SPP"""
+
+    def __init__(
+        self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"
+    ):
+        super().__init__()
+        hidden_channels = in_channels // 2
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
+        self.m = nn.ModuleList(
+            [
+                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
+                for ks in kernel_sizes
+            ]
+        )
+        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
+        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
+        x = self.conv2(x)
+        return x
+
+
+class CSPLayer(nn.Module):
+    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        n=1,
+        shortcut=True,
+        expansion=0.5,
+        depthwise=False,
+        act="silu",
+    ):
+        """
+        Args:
+            in_channels (int): input channels.
+            out_channels (int): output channels.
+            n (int): number of Bottlenecks. Default value: 1.
+        """
+        # ch_in, ch_out, number, shortcut, groups, expansion
+        super().__init__()
+        hidden_channels = int(out_channels * expansion)  # hidden channels
+        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
+        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
+        module_list = [
+            Bottleneck(
+                hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
+            )
+            for _ in range(n)
+        ]
+        self.m = nn.Sequential(*module_list)
+
+    def forward(self, x):
+        x_1 = self.conv1(x)
+        x_2 = self.conv2(x)
+        x_1 = self.m(x_1)
+        x = torch.cat((x_1, x_2), dim=1)
+        return self.conv3(x)
+
+
+class Focus(nn.Module):
+    """Focus width and height information into channel space."""
+
+    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
+        super().__init__()
+        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
+
+    def forward(self, x):
+        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
+        patch_top_left = x[..., ::2, ::2]
+        patch_top_right = x[..., ::2, 1::2]
+        patch_bot_left = x[..., 1::2, ::2]
+        patch_bot_right = x[..., 1::2, 1::2]
+        x = torch.cat(
+            (
+                patch_top_left,
+                patch_bot_left,
+                patch_top_right,
+                patch_bot_right,
+            ),
+            dim=1,
+        )
+        return self.conv(x)
diff --git a/multimodal/YOLOX/yolox/models/yolo_fpn.py b/multimodal/YOLOX/yolox/models/yolo_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..224271f59fd55b1e8e4bf3321d746a85bfe0b09c
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/yolo_fpn.py
@@ -0,0 +1,84 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from .darknet import Darknet
+from .network_blocks import BaseConv
+
+
+class YOLOFPN(nn.Module):
+    """
+    YOLOFPN module. Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self,
+        depth=53,
+        in_features=["dark3", "dark4", "dark5"],
+    ):
+        super().__init__()
+
+        self.backbone = Darknet(depth)
+        self.in_features = in_features
+
+        # out 1
+        self.out1_cbl = self._make_cbl(512, 256, 1)
+        self.out1 = self._make_embedding([256, 512], 512 + 256)
+
+        # out 2
+        self.out2_cbl = self._make_cbl(256, 128, 1)
+        self.out2 = self._make_embedding([128, 256], 256 + 128)
+
+        # upsample
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+
+    def _make_cbl(self, _in, _out, ks):
+        return BaseConv(_in, _out, ks, stride=1, act="lrelu")
+
+    def _make_embedding(self, filters_list, in_filters):
+        m = nn.Sequential(
+            *[
+                self._make_cbl(in_filters, filters_list[0], 1),
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+                self._make_cbl(filters_list[0], filters_list[1], 3),
+                self._make_cbl(filters_list[1], filters_list[0], 1),
+            ]
+        )
+        return m
+
+    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
+        with open(filename, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu")
+        print("loading pretrained weights...")
+        self.backbone.load_state_dict(state_dict)
+
+    def forward(self, inputs):
+        """
+        Args:
+            inputs (Tensor): input image.
+
+        Returns:
+            Tuple[Tensor]: FPN output features..
+        """
+        #  backbone
+        out_features = self.backbone(inputs)
+        x2, x1, x0 = [out_features[f] for f in self.in_features]
+
+        #  yolo branch 1
+        x1_in = self.out1_cbl(x0)
+        x1_in = self.upsample(x1_in)
+        x1_in = torch.cat([x1_in, x1], 1)
+        out_dark4 = self.out1(x1_in)
+
+        #  yolo branch 2
+        x2_in = self.out2_cbl(out_dark4)
+        x2_in = self.upsample(x2_in)
+        x2_in = torch.cat([x2_in, x2], 1)
+        out_dark3 = self.out2(x2_in)
+
+        outputs = (out_dark3, out_dark4, x0)
+        return outputs
diff --git a/multimodal/YOLOX/yolox/models/yolo_head.py b/multimodal/YOLOX/yolox/models/yolo_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a614bb1e95a6917341c9df8fc3ac87ec6260f6b0
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/yolo_head.py
@@ -0,0 +1,691 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import math
+from loguru import logger
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from yolox.utils import bboxes_iou, cxcywh2xyxy, meshgrid, visualize_assign
+
+from .losses import IOUloss
+from .network_blocks import BaseConv, DWConv
+
+
+class YOLOXHead(nn.Module):
+    def __init__(
+        self,
+        num_classes,
+        width=1.0,
+        strides=[8, 16, 32],
+        in_channels=[256, 512, 1024],
+        act="silu",
+        depthwise=False,
+    ):
+        """
+        Args:
+            act (str): activation type of conv. Defalut value: "silu".
+            depthwise (bool): whether apply depthwise conv in conv branch. Defalut value: False.
+        """
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.decode_in_inference = True  # for deploy, set to False
+
+        self.cls_convs = nn.ModuleList()
+        self.reg_convs = nn.ModuleList()
+        self.cls_preds = nn.ModuleList()
+        self.reg_preds = nn.ModuleList()
+        self.obj_preds = nn.ModuleList()
+        self.stems = nn.ModuleList()
+        Conv = DWConv if depthwise else BaseConv
+
+        for i in range(len(in_channels)):
+            self.stems.append(
+                BaseConv(
+                    in_channels=int(in_channels[i] * width),
+                    out_channels=int(256 * width),
+                    ksize=1,
+                    stride=1,
+                    act=act,
+                )
+            )
+            self.cls_convs.append(
+                nn.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.reg_convs.append(
+                nn.Sequential(
+                    *[
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                        Conv(
+                            in_channels=int(256 * width),
+                            out_channels=int(256 * width),
+                            ksize=3,
+                            stride=1,
+                            act=act,
+                        ),
+                    ]
+                )
+            )
+            self.cls_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=self.num_classes,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.reg_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=4,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+            self.obj_preds.append(
+                nn.Conv2d(
+                    in_channels=int(256 * width),
+                    out_channels=1,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                )
+            )
+
+        self.use_l1 = False
+        self.l1_loss = nn.L1Loss(reduction="none")
+        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
+        self.iou_loss = IOUloss(reduction="none")
+        self.strides = strides
+        self.grids = [torch.zeros(1)] * len(in_channels)
+
+    def initialize_biases(self, prior_prob):
+        for conv in self.cls_preds:
+            b = conv.bias.view(1, -1)
+            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
+            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+        for conv in self.obj_preds:
+            b = conv.bias.view(1, -1)
+            b.data.fill_(-math.log((1 - prior_prob) / prior_prob))
+            conv.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def forward(self, xin, labels=None, imgs=None):
+        outputs = []
+        origin_preds = []
+        x_shifts = []
+        y_shifts = []
+        expanded_strides = []
+
+        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+            zip(self.cls_convs, self.reg_convs, self.strides, xin)
+        ):
+            # print("before stems x", torch.isnan(x).any())
+            x = self.stems[k](x)
+            cls_x = x
+            reg_x = x
+
+            cls_feat = cls_conv(cls_x)
+            cls_output = self.cls_preds[k](cls_feat)
+
+            reg_feat = reg_conv(reg_x)
+            reg_output = self.reg_preds[k](reg_feat)
+            obj_output = self.obj_preds[k](reg_feat)
+
+
+            # DEBUG HERE
+            # print("="*80)
+            # print("x", torch.isnan(x).any())
+            # print("cls_feat", torch.isnan(cls_feat).any())
+            # print("reg_feat", torch.isnan(reg_feat).any())
+            # print("cls_output", torch.isnan(cls_output).any())
+            # print("reg_output", torch.isnan(reg_output).any())
+            # print("obj_output", torch.isnan(obj_output).any())
+            # if torch.isnan(obj_output).any():
+            #     if torch.distributed.get_rank() == 0:
+            #         import pdb; pdb.set_trace()
+            #     else:
+            #         torch.distributed.barrier()
+            # print("="*80)
+
+
+            if self.training:
+                output = torch.cat([reg_output, obj_output, cls_output], 1)
+                output, grid = self.get_output_and_grid(
+                    output, k, stride_this_level, xin[0].type()
+                )
+                x_shifts.append(grid[:, :, 0])
+                y_shifts.append(grid[:, :, 1])
+                expanded_strides.append(
+                    torch.zeros(1, grid.shape[1])
+                    .fill_(stride_this_level)
+                    .type_as(xin[0])
+                )
+                if self.use_l1:
+                    batch_size = reg_output.shape[0]
+                    hsize, wsize = reg_output.shape[-2:]
+                    reg_output = reg_output.view(
+                        batch_size, 1, 4, hsize, wsize
+                    )
+                    reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
+                        batch_size, -1, 4
+                    )
+                    origin_preds.append(reg_output.clone())
+
+            else:
+                output = torch.cat(
+                    [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
+                )
+
+            outputs.append(output)
+
+        if self.training:
+            return self.get_losses(
+                imgs,
+                x_shifts,
+                y_shifts,
+                expanded_strides,
+                labels,
+                torch.cat(outputs, 1),
+                origin_preds,
+                dtype=xin[0].dtype,
+            )
+        else:
+            self.hw = [x.shape[-2:] for x in outputs]
+            # [batch, n_anchors_all, 85]
+            outputs = torch.cat(
+                [x.flatten(start_dim=2) for x in outputs], dim=2
+            ).permute(0, 2, 1)
+            if self.decode_in_inference:
+                return self.decode_outputs(outputs, dtype=xin[0].type())
+            else:
+                return outputs
+
+    def get_output_and_grid(self, output, k, stride, dtype):
+        grid = self.grids[k]
+
+        batch_size = output.shape[0]
+        n_ch = 5 + self.num_classes
+        hsize, wsize = output.shape[-2:]
+        if grid.shape[2:4] != output.shape[2:4]:
+            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
+            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
+            self.grids[k] = grid
+
+        output = output.view(batch_size, 1, n_ch, hsize, wsize)
+        output = output.permute(0, 1, 3, 4, 2).reshape(
+            batch_size, hsize * wsize, -1
+        )
+        grid = grid.view(1, -1, 2)
+        output[..., :2] = (output[..., :2] + grid) * stride
+        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
+        return output, grid
+
+    def decode_outputs(self, outputs, dtype):
+        grids = []
+        strides = []
+        for (hsize, wsize), stride in zip(self.hw, self.strides):
+            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
+            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
+            grids.append(grid)
+            shape = grid.shape[:2]
+            strides.append(torch.full((*shape, 1), stride))
+
+        grids = torch.cat(grids, dim=1).type(dtype)
+        strides = torch.cat(strides, dim=1).type(dtype)
+
+        outputs = torch.cat([
+            (outputs[..., 0:2] + grids) * strides,
+            torch.exp(outputs[..., 2:4]) * strides,
+            outputs[..., 4:]
+        ], dim=-1)
+        return outputs
+
+    def get_losses(
+        self,
+        imgs,
+        x_shifts,
+        y_shifts,
+        expanded_strides,
+        labels,
+        outputs,
+        origin_preds,
+        dtype,
+    ):
+        bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
+        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
+        cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
+
+        # calculate targets
+        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects
+
+        total_num_anchors = outputs.shape[1]
+        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
+        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
+        expanded_strides = torch.cat(expanded_strides, 1)
+        if self.use_l1:
+            origin_preds = torch.cat(origin_preds, 1)
+
+        cls_targets = []
+        reg_targets = []
+        l1_targets = []
+        obj_targets = []
+        fg_masks = []
+
+        num_fg = 0.0
+        num_gts = 0.0
+
+        for batch_idx in range(outputs.shape[0]):
+            num_gt = int(nlabel[batch_idx])
+            num_gts += num_gt
+            if num_gt == 0:
+                cls_target = outputs.new_zeros((0, self.num_classes))
+                reg_target = outputs.new_zeros((0, 4))
+                l1_target = outputs.new_zeros((0, 4))
+                obj_target = outputs.new_zeros((total_num_anchors, 1))
+                fg_mask = outputs.new_zeros(total_num_anchors).bool()
+            else:
+                gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]
+                gt_classes = labels[batch_idx, :num_gt, 0]
+                bboxes_preds_per_image = bbox_preds[batch_idx]
+
+                try:
+                    (
+                        gt_matched_classes,
+                        fg_mask,
+                        pred_ious_this_matching,
+                        matched_gt_inds,
+                        num_fg_img,
+                    ) = self.get_assignments(  # noqa
+                        batch_idx,
+                        num_gt,
+                        gt_bboxes_per_image,
+                        gt_classes,
+                        bboxes_preds_per_image,
+                        expanded_strides,
+                        x_shifts,
+                        y_shifts,
+                        cls_preds,
+                        obj_preds,
+                    )
+                except RuntimeError as e:
+                    # TODO: the string might change, consider a better way
+                    if "CUDA out of memory. " not in str(e):
+                        raise  # RuntimeError might not caused by CUDA OOM
+
+                    logger.error(
+                        "OOM RuntimeError is raised due to the huge memory cost during label assignment. \
+                           CPU mode is applied in this batch. If you want to avoid this issue, \
+                           try to reduce the batch size or image size."
+                    )
+                    torch.cuda.empty_cache()
+                    (
+                        gt_matched_classes,
+                        fg_mask,
+                        pred_ious_this_matching,
+                        matched_gt_inds,
+                        num_fg_img,
+                    ) = self.get_assignments(  # noqa
+                        batch_idx,
+                        num_gt,
+                        gt_bboxes_per_image,
+                        gt_classes,
+                        bboxes_preds_per_image,
+                        expanded_strides,
+                        x_shifts,
+                        y_shifts,
+                        cls_preds,
+                        obj_preds,
+                        "cpu",
+                    )
+
+                if num_fg_img == 0:
+                    cls_target = outputs.new_zeros((0, self.num_classes))
+                    reg_target = outputs.new_zeros((0, 4))
+                    if self.use_l1:
+                        l1_target = outputs.new_zeros((0, 4))
+                    obj_target = outputs.new_zeros((total_num_anchors, 1))
+                    fg_mask = outputs.new_zeros(total_num_anchors).bool()
+                else:
+                    torch.cuda.empty_cache()
+                    num_fg += num_fg_img
+
+                    cls_target = F.one_hot(
+                        gt_matched_classes.to(torch.int64), self.num_classes
+                    ) * pred_ious_this_matching.unsqueeze(-1)
+                    obj_target = fg_mask.unsqueeze(-1)
+                    reg_target = gt_bboxes_per_image[matched_gt_inds]
+                    if self.use_l1:
+                        l1_target = self.get_l1_target(
+                            outputs.new_zeros((num_fg_img, 4)),
+                            gt_bboxes_per_image[matched_gt_inds],
+                            expanded_strides[0][fg_mask],
+                            x_shifts=x_shifts[0][fg_mask],
+                            y_shifts=y_shifts[0][fg_mask],
+                        )
+
+            cls_targets.append(cls_target)
+            reg_targets.append(reg_target)
+            obj_targets.append(obj_target.to(dtype))
+            fg_masks.append(fg_mask)
+            if self.use_l1:
+                l1_targets.append(l1_target)
+
+        cls_targets = torch.cat(cls_targets, 0)
+        reg_targets = torch.cat(reg_targets, 0)
+        obj_targets = torch.cat(obj_targets, 0)
+        fg_masks = torch.cat(fg_masks, 0)
+        if self.use_l1:
+            l1_targets = torch.cat(l1_targets, 0)
+
+        num_fg = max(num_fg, 1)
+        loss_iou = (
+            self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
+        ).sum() / num_fg
+        loss_obj = (
+            self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
+        ).sum() / num_fg
+        loss_cls = (
+            self.bcewithlog_loss(
+                cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
+            )
+        ).sum() / num_fg
+        if self.use_l1:
+            loss_l1 = (
+                self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)
+            ).sum() / num_fg
+        else:
+            loss_l1 = 0.0
+
+        reg_weight = 5.0
+        loss_cls *= 0.0
+        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1
+
+        return (
+            loss,
+            reg_weight * loss_iou,
+            loss_obj,
+            loss_cls,
+            loss_l1,
+            num_fg / max(num_gts, 1),
+        )
+
+    def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
+        l1_target[:, 0] = gt[:, 0] / stride - x_shifts
+        l1_target[:, 1] = gt[:, 1] / stride - y_shifts
+        l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
+        l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
+        return l1_target
+
+    @torch.no_grad()
+    def get_assignments(
+        self,
+        batch_idx,
+        num_gt,
+        gt_bboxes_per_image,
+        gt_classes,
+        bboxes_preds_per_image,
+        expanded_strides,
+        x_shifts,
+        y_shifts,
+        cls_preds,
+        obj_preds,
+        mode="gpu",
+    ):
+
+        if mode == "cpu":
+            print("-----------Using CPU for the Current Batch-------------")
+            gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
+            bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
+            gt_classes = gt_classes.cpu().float()
+            expanded_strides = expanded_strides.cpu().float()
+            x_shifts = x_shifts.cpu()
+            y_shifts = y_shifts.cpu()
+
+        fg_mask, geometry_relation = self.get_geometry_constraint(
+            gt_bboxes_per_image,
+            expanded_strides,
+            x_shifts,
+            y_shifts,
+        )
+
+        # NOTE: Fix `selected index k out of range`
+        npa: int = fg_mask.sum().item()  # number of positive anchors
+        if npa == 0:
+            gt_matched_classes = torch.zeros(0, device=fg_mask.device).long()
+            pred_ious_this_matching = torch.rand(0, device=fg_mask.device)
+            matched_gt_inds = gt_matched_classes
+            num_fg = npa
+
+            if mode == "cpu":
+                gt_matched_classes = gt_matched_classes.cuda()
+                fg_mask = fg_mask.cuda()
+                pred_ious_this_matching = pred_ious_this_matching.cuda()
+                matched_gt_inds = matched_gt_inds.cuda()
+                num_fg = num_fg.cuda()
+
+            return (
+                gt_matched_classes,
+                fg_mask,
+                pred_ious_this_matching,
+                matched_gt_inds,
+                num_fg,
+            )
+
+        bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
+        cls_preds_ = cls_preds[batch_idx][fg_mask]
+        obj_preds_ = obj_preds[batch_idx][fg_mask]
+        num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
+
+        if mode == "cpu":
+            gt_bboxes_per_image = gt_bboxes_per_image.cpu()
+            bboxes_preds_per_image = bboxes_preds_per_image.cpu()
+
+        pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
+
+        gt_cls_per_image = (
+            F.one_hot(gt_classes.to(torch.int64), self.num_classes)
+            .float()
+        )
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
+
+        if mode == "cpu":
+            cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()
+
+        with torch.cuda.amp.autocast(enabled=False):
+            cls_preds_ = (
+                cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
+            ).sqrt()
+            pair_wise_cls_loss = F.binary_cross_entropy(
+                cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
+                gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
+                reduction="none"
+            ).sum(-1)
+        del cls_preds_
+
+        cost = (
+            pair_wise_cls_loss * 0.0
+            + 3.0 * pair_wise_ious_loss
+            + float(1e6) * (~geometry_relation)
+        )
+
+        (
+            num_fg,
+            gt_matched_classes,
+            pred_ious_this_matching,
+            matched_gt_inds,
+        ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
+        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
+
+        if mode == "cpu":
+            gt_matched_classes = gt_matched_classes.cuda()
+            fg_mask = fg_mask.cuda()
+            pred_ious_this_matching = pred_ious_this_matching.cuda()
+            matched_gt_inds = matched_gt_inds.cuda()
+
+        return (
+            gt_matched_classes,
+            fg_mask,
+            pred_ious_this_matching,
+            matched_gt_inds,
+            num_fg,
+        )
+
+    def get_geometry_constraint(
+        self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
+    ):
+        """
+        Calculate whether the center of an object is located in a fixed range of
+        an anchor. This is used to avert inappropriate matching. It can also reduce
+        the number of candidate anchors so that the GPU memory is saved.
+        """
+        expanded_strides_per_image = expanded_strides[0]
+        x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
+        y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
+
+        # in fixed center
+        center_radius = 1.5
+        center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius
+        gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist
+        gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
+        gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
+        gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist
+
+        c_l = x_centers_per_image - gt_bboxes_per_image_l
+        c_r = gt_bboxes_per_image_r - x_centers_per_image
+        c_t = y_centers_per_image - gt_bboxes_per_image_t
+        c_b = gt_bboxes_per_image_b - y_centers_per_image
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        anchor_filter = is_in_centers.sum(dim=0) > 0
+        geometry_relation = is_in_centers[:, anchor_filter]
+
+        return anchor_filter, geometry_relation
+
+    def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        n_candidate_k = min(10, pair_wise_ious.size(1))
+        topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(
+                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+            )
+            matching_matrix[gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        anchor_matching_gt = matching_matrix.sum(0)
+        # deal with the case that one anchor matches multiple ground-truths
+        if anchor_matching_gt.max() > 1:
+            multiple_match_mask = anchor_matching_gt > 1
+            _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
+            matching_matrix[:, multiple_match_mask] *= 0
+            matching_matrix[cost_argmin, multiple_match_mask] = 1
+        fg_mask_inboxes = anchor_matching_gt > 0
+        num_fg = fg_mask_inboxes.sum().item()
+
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        gt_matched_classes = gt_classes[matched_gt_inds]
+
+        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
+            fg_mask_inboxes
+        ]
+        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
+
+    def visualize_assign_result(self, xin, labels=None, imgs=None, save_prefix="assign_vis_"):
+        # original forward logic
+        outputs, x_shifts, y_shifts, expanded_strides = [], [], [], []
+        # TODO: use forward logic here.
+
+        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+            zip(self.cls_convs, self.reg_convs, self.strides, xin)
+        ):
+            x = self.stems[k](x)
+            cls_x = x
+            reg_x = x
+
+            cls_feat = cls_conv(cls_x)
+            cls_output = self.cls_preds[k](cls_feat)
+            reg_feat = reg_conv(reg_x)
+            reg_output = self.reg_preds[k](reg_feat)
+            obj_output = self.obj_preds[k](reg_feat)
+
+            output = torch.cat([reg_output, obj_output, cls_output], 1)
+            output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())
+            x_shifts.append(grid[:, :, 0])
+            y_shifts.append(grid[:, :, 1])
+            expanded_strides.append(
+                torch.full((1, grid.shape[1]), stride_this_level).type_as(xin[0])
+            )
+            outputs.append(output)
+
+        outputs = torch.cat(outputs, 1)
+        bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
+        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
+        cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
+
+        # calculate targets
+        total_num_anchors = outputs.shape[1]
+        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
+        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
+        expanded_strides = torch.cat(expanded_strides, 1)
+
+        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects
+        for batch_idx, (img, num_gt, label) in enumerate(zip(imgs, nlabel, labels)):
+            img = imgs[batch_idx].permute(1, 2, 0).to(torch.uint8)
+            num_gt = int(num_gt)
+            if num_gt == 0:
+                fg_mask = outputs.new_zeros(total_num_anchors).bool()
+            else:
+                gt_bboxes_per_image = label[:num_gt, 1:5]
+                gt_classes = label[:num_gt, 0]
+                bboxes_preds_per_image = bbox_preds[batch_idx]
+                _, fg_mask, _, matched_gt_inds, _ = self.get_assignments(  # noqa
+                    batch_idx, num_gt, gt_bboxes_per_image, gt_classes,
+                    bboxes_preds_per_image, expanded_strides, x_shifts,
+                    y_shifts, cls_preds, obj_preds,
+                )
+
+            img = img.cpu().numpy().copy()  # copy is crucial here
+            coords = torch.stack([
+                ((x_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
+                ((y_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
+            ], 1)
+
+            xyxy_boxes = cxcywh2xyxy(gt_bboxes_per_image)
+            save_name = save_prefix + str(batch_idx) + ".png"
+            img = visualize_assign(img, xyxy_boxes, coords, matched_gt_inds, save_name)
+            logger.info(f"save img to {save_name}")
diff --git a/multimodal/YOLOX/yolox/models/yolo_pafpn.py b/multimodal/YOLOX/yolox/models/yolo_pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c4e18a5c3273ecdd878444cc42965e6a24a0cd1
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/yolo_pafpn.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+
+from .darknet import CSPDarknet
+from .network_blocks import BaseConv, CSPLayer, DWConv
+
+
+class YOLOPAFPN(nn.Module):
+    """
+    YOLOv3 model. Darknet 53 is the default backbone of this model.
+    """
+
+    def __init__(
+        self,
+        depth=1.0,
+        width=1.0,
+        in_features=("dark3", "dark4", "dark5"),
+        in_channels=[256, 512, 1024],
+        depthwise=False,
+        act="silu",
+    ):
+        super().__init__()
+        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
+        self.in_features = in_features
+        self.in_channels = in_channels
+        Conv = DWConv if depthwise else BaseConv
+
+        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
+        self.lateral_conv0 = BaseConv(
+            int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
+        )
+        self.C3_p4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )  # cat
+
+        self.reduce_conv1 = BaseConv(
+            int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
+        )
+        self.C3_p3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[0] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+        # bottom-up conv
+        self.bu_conv2 = Conv(
+            int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
+        )
+        self.C3_n3 = CSPLayer(
+            int(2 * in_channels[0] * width),
+            int(in_channels[1] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+        # bottom-up conv
+        self.bu_conv1 = Conv(
+            int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
+        )
+        self.C3_n4 = CSPLayer(
+            int(2 * in_channels[1] * width),
+            int(in_channels[2] * width),
+            round(3 * depth),
+            False,
+            depthwise=depthwise,
+            act=act,
+        )
+
+    def forward(self, input):
+        """
+        Args:
+            inputs: input images.
+
+        Returns:
+            Tuple[Tensor]: FPN feature.
+        """
+
+        #  backbone
+        out_features = self.backbone(input)
+        features = [out_features[f] for f in self.in_features]
+        [x2, x1, x0] = features
+
+        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
+        f_out0 = self.upsample(fpn_out0)  # 512/16
+        f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
+        f_out0 = self.C3_p4(f_out0)  # 1024->512/16
+
+        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
+        f_out1 = self.upsample(fpn_out1)  # 256/8
+        f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
+        pan_out2 = self.C3_p3(f_out1)  # 512->256/8
+
+        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
+        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
+        pan_out1 = self.C3_n3(p_out1)  # 512->512/16
+
+        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
+        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
+        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32
+
+        outputs = (pan_out2, pan_out1, pan_out0)
+        return outputs
diff --git a/multimodal/YOLOX/yolox/models/yolox.py b/multimodal/YOLOX/yolox/models/yolox.py
new file mode 100644
index 0000000000000000000000000000000000000000..744ceea818e8f92ae422288ce7efba9842d9e28c
--- /dev/null
+++ b/multimodal/YOLOX/yolox/models/yolox.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch.nn as nn
+
+from .yolo_head import YOLOXHead
+from .yolo_pafpn import YOLOPAFPN
+
+
+class YOLOX(nn.Module):
+    """
+    YOLOX model module. The module list is defined by create_yolov3_modules function.
+    The network returns loss values from three YOLO layers during training
+    and detection results during test.
+    """
+
+    def __init__(self, backbone=None, head=None):
+        super().__init__()
+        if backbone is None:
+            backbone = YOLOPAFPN()
+        if head is None:
+            head = YOLOXHead(80)
+
+        self.backbone = backbone
+        self.head = head
+
+    def forward(self, x, targets=None):
+        # fpn output content features of [dark3, dark4, dark5]
+        fpn_outs = self.backbone(x)
+
+        if self.training:
+            assert targets is not None
+            loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
+                fpn_outs, targets, x
+            )
+            outputs = {
+                "total_loss": loss,
+                "iou_loss": iou_loss,
+                "l1_loss": l1_loss,
+                "conf_loss": conf_loss,
+                "cls_loss": cls_loss,
+                "num_fg": num_fg,
+            }
+        else:
+            outputs = self.head(fpn_outs)
+
+        return outputs
+
+    def visualize(self, x, targets, save_prefix="assign_vis_"):
+        fpn_outs = self.backbone(x)
+        self.head.visualize_assign_result(fpn_outs, targets, x, save_prefix)
diff --git a/multimodal/YOLOX/yolox/tools/__init__.py b/multimodal/YOLOX/yolox/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0944290b8d12c660ad8068d0b40ee1dbf8fd5938
--- /dev/null
+++ b/multimodal/YOLOX/yolox/tools/__init__.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+# This file is used for package installation. Script of train/eval/export will be available.
+
+import sys
+from importlib import abc, util
+from pathlib import Path
+
+_TOOLS_PATH = Path(__file__).resolve().parent.parent.parent / "tools"
+
+if _TOOLS_PATH.is_dir():
+    # This is true only for in-place installation (pip install -e, setup.py develop),
+    # where setup(package_dir=) does not work: https://github.com/pypa/setuptools/issues/230
+
+    class _PathFinder(abc.MetaPathFinder):
+
+        def find_spec(self, name, path, target=None):
+            if not name.startswith("yolox.tools."):
+                return
+            project_name = name.split(".")[-1] + ".py"
+            target_file = _TOOLS_PATH / project_name
+            if not target_file.is_file():
+                return
+            return util.spec_from_file_location(name, target_file)
+
+    sys.meta_path.append(_PathFinder())
diff --git a/multimodal/YOLOX/yolox/utils/__init__.py b/multimodal/YOLOX/yolox/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..08e6dae986b367ec1806c271b0c371cd17e89133
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/__init__.py
@@ -0,0 +1,16 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+from .allreduce_norm import *
+from .boxes import *
+from .checkpoint import load_ckpt, save_checkpoint
+from .compat import meshgrid
+from .demo_utils import *
+from .dist import *
+from .ema import *
+from .logger import WandbLogger, setup_logger
+from .lr_scheduler import LRScheduler
+from .metric import *
+from .model_utils import *
+from .setup_env import *
+from .visualize import *
diff --git a/multimodal/YOLOX/yolox/utils/allreduce_norm.py b/multimodal/YOLOX/yolox/utils/allreduce_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..142c76c78061db6e2c5f4b899bcc5e2f2214f010
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/allreduce_norm.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import pickle
+from collections import OrderedDict
+
+import torch
+from torch import distributed as dist
+from torch import nn
+
+from .dist import _get_global_gloo_group, get_world_size
+
+ASYNC_NORM = (
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.InstanceNorm1d,
+    nn.InstanceNorm2d,
+    nn.InstanceNorm3d,
+)
+
+__all__ = [
+    "get_async_norm_states",
+    "pyobj2tensor",
+    "tensor2pyobj",
+    "all_reduce",
+    "all_reduce_norm",
+]
+
+
+def get_async_norm_states(module):
+    async_norm_states = OrderedDict()
+    for name, child in module.named_modules():
+        if isinstance(child, ASYNC_NORM):
+            for k, v in child.state_dict().items():
+                async_norm_states[".".join([name, k])] = v
+    return async_norm_states
+
+
+def pyobj2tensor(pyobj, device="cuda"):
+    """serialize picklable python object to tensor"""
+    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
+    return torch.ByteTensor(storage).to(device=device)
+
+
+def tensor2pyobj(tensor):
+    """deserialize tensor to picklable python object"""
+    return pickle.loads(tensor.cpu().numpy().tobytes())
+
+
+def _get_reduce_op(op_name):
+    return {
+        "sum": dist.ReduceOp.SUM,
+        "mean": dist.ReduceOp.SUM,
+    }[op_name.lower()]
+
+
+def all_reduce(py_dict, op="sum", group=None):
+    """
+    Apply all reduce function for python dict object.
+    NOTE: make sure that every py_dict has the same keys and values are in the same shape.
+
+    Args:
+        py_dict (dict): dict to apply all reduce op.
+        op (str): operator, could be "sum" or "mean".
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return py_dict
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return py_dict
+
+    # all reduce logic across different devices.
+    py_key = list(py_dict.keys())
+    py_key_tensor = pyobj2tensor(py_key)
+    dist.broadcast(py_key_tensor, src=0)
+    py_key = tensor2pyobj(py_key_tensor)
+
+    tensor_shapes = [py_dict[k].shape for k in py_key]
+    tensor_numels = [py_dict[k].numel() for k in py_key]
+
+    flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])
+    dist.all_reduce(flatten_tensor, op=_get_reduce_op(op))
+    if op == "mean":
+        flatten_tensor /= world_size
+
+    split_tensors = [
+        x.reshape(shape)
+        for x, shape in zip(torch.split(flatten_tensor, tensor_numels), tensor_shapes)
+    ]
+    return OrderedDict({k: v for k, v in zip(py_key, split_tensors)})
+
+
+def all_reduce_norm(module):
+    """
+    All reduce norm statistics in different devices.
+    """
+    states = get_async_norm_states(module)
+    states = all_reduce(states, op="mean")
+    module.load_state_dict(states, strict=False)
diff --git a/multimodal/YOLOX/yolox/utils/boxes.py b/multimodal/YOLOX/yolox/utils/boxes.py
new file mode 100644
index 0000000000000000000000000000000000000000..f71e8d90b67bc8d67644880e0c29b5f87c99b043
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/boxes.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import numpy as np
+
+import torch
+import torchvision
+
+__all__ = [
+    "filter_box",
+    "postprocess",
+    "bboxes_iou",
+    "matrix_iou",
+    "adjust_box_anns",
+    "xyxy2xywh",
+    "xyxy2cxcywh",
+    "cxcywh2xyxy",
+]
+
+
+def filter_box(output, scale_range):
+    """
+    output: (N, 5+class) shape
+    """
+    min_scale, max_scale = scale_range
+    w = output[:, 2] - output[:, 0]
+    h = output[:, 3] - output[:, 1]
+    keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale)
+    return output[keep]
+
+
+def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
+    box_corner = prediction.new(prediction.shape)
+    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+    prediction[:, :, :4] = box_corner[:, :, :4]
+
+    output = [None for _ in range(len(prediction))]
+    for i, image_pred in enumerate(prediction):
+
+        # If none are remaining => process next image
+        if not image_pred.size(0):
+            continue
+        # Get score and class with highest confidence
+        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)
+
+        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
+        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
+        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
+        detections = detections[conf_mask]
+        if not detections.size(0):
+            continue
+
+        if class_agnostic:
+            nms_out_index = torchvision.ops.nms(
+                detections[:, :4],
+                detections[:, 4] * detections[:, 5],
+                nms_thre,
+            )
+        else:
+            nms_out_index = torchvision.ops.batched_nms(
+                detections[:, :4],
+                detections[:, 4] * detections[:, 5],
+                detections[:, 6],
+                nms_thre,
+            )
+
+        detections = detections[nms_out_index]
+        if output[i] is None:
+            output[i] = detections
+        else:
+            output[i] = torch.cat((output[i], detections))
+
+    return output
+
+
+def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
+    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
+        raise IndexError
+
+    if xyxy:
+        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
+        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
+        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
+        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
+    else:
+        tl = torch.max(
+            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
+        )
+        br = torch.min(
+            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
+            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
+        )
+
+        area_a = torch.prod(bboxes_a[:, 2:], 1)
+        area_b = torch.prod(bboxes_b[:, 2:], 1)
+    en = (tl < br).type(tl.type()).prod(dim=2)
+    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
+    return area_i / (area_a[:, None] + area_b - area_i)
+
+
+def matrix_iou(a, b):
+    """
+    return iou of a and b, numpy version for data augenmentation
+    """
+    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+    return area_i / (area_a[:, np.newaxis] + area_b - area_i + 1e-12)
+
+
+def adjust_box_anns(bbox, scale_ratio, padw, padh, w_max, h_max):
+    bbox[:, 0::2] = np.clip(bbox[:, 0::2] * scale_ratio + padw, 0, w_max)
+    bbox[:, 1::2] = np.clip(bbox[:, 1::2] * scale_ratio + padh, 0, h_max)
+    return bbox
+
+
+def xyxy2xywh(bboxes):
+    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+    return bboxes
+
+
+def xyxy2cxcywh(bboxes):
+    bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+    bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+    bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
+    bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
+    return bboxes
+
+
+def cxcywh2xyxy(bboxes):
+    bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5
+    bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5
+    bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
+    bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
+    return bboxes
diff --git a/multimodal/YOLOX/yolox/utils/checkpoint.py b/multimodal/YOLOX/yolox/utils/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0c200e41da9ad8b720369a2181c9642724622ca
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/checkpoint.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+import os
+import shutil
+from loguru import logger
+
+import torch
+
+
+def load_ckpt(model, ckpt):
+    model_state_dict = model.state_dict()
+    load_dict = {}
+    for key_model, v in model_state_dict.items():
+        if key_model not in ckpt:
+            logger.warning(
+                "{} is not in the ckpt. Please double check and see if this is desired.".format(
+                    key_model
+                )
+            )
+            continue
+        v_ckpt = ckpt[key_model]
+        if v.shape != v_ckpt.shape:
+            logger.warning(
+                "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
+                    key_model, v_ckpt.shape, key_model, v.shape
+                )
+            )
+            continue
+        load_dict[key_model] = v_ckpt
+
+    model.load_state_dict(load_dict, strict=False)
+    return model
+
+
+def save_checkpoint(state, is_best, save_dir, model_name=""):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    filename = os.path.join(save_dir, model_name + "_ckpt.pth")
+    torch.save(state, filename)
+    if is_best:
+        best_filename = os.path.join(save_dir, "best_ckpt.pth")
+        shutil.copyfile(filename, best_filename)
diff --git a/multimodal/YOLOX/yolox/utils/compat.py b/multimodal/YOLOX/yolox/utils/compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..1324077e67215451aa8351f47f5112cd0e5e1018
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/compat.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+
+_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
+
+__all__ = ["meshgrid"]
+
+
+def meshgrid(*tensors):
+    if _TORCH_VER >= [1, 10]:
+        return torch.meshgrid(*tensors, indexing="ij")
+    else:
+        return torch.meshgrid(*tensors)
diff --git a/multimodal/YOLOX/yolox/utils/demo_utils.py b/multimodal/YOLOX/yolox/utils/demo_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..56dd33686f03c4ec1b82a79e3dadcd49fec6c0bb
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/demo_utils.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import os
+import random
+
+import cv2
+import numpy as np
+
+__all__ = [
+    "mkdir", "nms", "multiclass_nms", "demo_postprocess", "random_color", "visualize_assign"
+]
+
+
+def random_color():
+    return random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)
+
+
+def visualize_assign(img, boxes, coords, match_results, save_name=None) -> np.ndarray:
+    """visualize label assign result.
+
+    Args:
+        img: img to visualize
+        boxes: gt boxes in xyxy format
+        coords: coords of matched anchors
+        match_results: match results of each gt box and coord.
+        save_name: name of save image, if None, image will not be saved. Default: None.
+    """
+    for box_id, box in enumerate(boxes):
+        x1, y1, x2, y2 = box
+        color = random_color()
+        assign_coords = coords[match_results == box_id]
+        if assign_coords.numel() == 0:
+            # unmatched boxes are red
+            color = (0, 0, 255)
+            cv2.putText(
+                img, "unmatched", (int(x1), int(y1) - 5),
+                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 1
+            )
+        else:
+            for coord in assign_coords:
+                # draw assigned anchor
+                cv2.circle(img, (int(coord[0]), int(coord[1])), 3, color, -1)
+        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
+
+    if save_name is not None:
+        cv2.imwrite(save_name, img)
+
+    return img
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def nms(boxes, scores, nms_thr):
+    """Single class NMS implemented in Numpy."""
+    x1 = boxes[:, 0]
+    y1 = boxes[:, 1]
+    x2 = boxes[:, 2]
+    y2 = boxes[:, 3]
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        inds = np.where(ovr <= nms_thr)[0]
+        order = order[inds + 1]
+
+    return keep
+
+
+def multiclass_nms(boxes, scores, nms_thr, score_thr, class_agnostic=True):
+    """Multiclass NMS implemented in Numpy"""
+    if class_agnostic:
+        nms_method = multiclass_nms_class_agnostic
+    else:
+        nms_method = multiclass_nms_class_aware
+    return nms_method(boxes, scores, nms_thr, score_thr)
+
+
+def multiclass_nms_class_aware(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in Numpy. Class-aware version."""
+    final_dets = []
+    num_classes = scores.shape[1]
+    for cls_ind in range(num_classes):
+        cls_scores = scores[:, cls_ind]
+        valid_score_mask = cls_scores > score_thr
+        if valid_score_mask.sum() == 0:
+            continue
+        else:
+            valid_scores = cls_scores[valid_score_mask]
+            valid_boxes = boxes[valid_score_mask]
+            keep = nms(valid_boxes, valid_scores, nms_thr)
+            if len(keep) > 0:
+                cls_inds = np.ones((len(keep), 1)) * cls_ind
+                dets = np.concatenate(
+                    [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
+                )
+                final_dets.append(dets)
+    if len(final_dets) == 0:
+        return None
+    return np.concatenate(final_dets, 0)
+
+
+def multiclass_nms_class_agnostic(boxes, scores, nms_thr, score_thr):
+    """Multiclass NMS implemented in Numpy. Class-agnostic version."""
+    cls_inds = scores.argmax(1)
+    cls_scores = scores[np.arange(len(cls_inds)), cls_inds]
+
+    valid_score_mask = cls_scores > score_thr
+    if valid_score_mask.sum() == 0:
+        return None
+    valid_scores = cls_scores[valid_score_mask]
+    valid_boxes = boxes[valid_score_mask]
+    valid_cls_inds = cls_inds[valid_score_mask]
+    keep = nms(valid_boxes, valid_scores, nms_thr)
+    if keep:
+        dets = np.concatenate(
+            [valid_boxes[keep], valid_scores[keep, None], valid_cls_inds[keep, None]], 1
+        )
+    return dets
+
+
+def demo_postprocess(outputs, img_size, p6=False):
+    grids = []
+    expanded_strides = []
+    strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
+
+    hsizes = [img_size[0] // stride for stride in strides]
+    wsizes = [img_size[1] // stride for stride in strides]
+
+    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
+        xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
+        grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
+        grids.append(grid)
+        shape = grid.shape[:2]
+        expanded_strides.append(np.full((*shape, 1), stride))
+
+    grids = np.concatenate(grids, 1)
+    expanded_strides = np.concatenate(expanded_strides, 1)
+    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
+    outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
+
+    return outputs
diff --git a/multimodal/YOLOX/yolox/utils/dist.py b/multimodal/YOLOX/yolox/utils/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8fea93346f2b52270c07ba61f2cc17c3c07047
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/dist.py
@@ -0,0 +1,294 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# This file mainly comes from
+# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/comm.py
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Megvii Inc. All rights reserved.
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import os
+import pickle
+import time
+from contextlib import contextmanager
+from loguru import logger
+
+import numpy as np
+
+import torch
+from torch import distributed as dist
+
+__all__ = [
+    "get_num_devices",
+    "wait_for_the_master",
+    "is_main_process",
+    "synchronize",
+    "get_world_size",
+    "get_rank",
+    "get_local_rank",
+    "get_local_size",
+    "time_synchronized",
+    "gather",
+    "all_gather",
+]
+
+_LOCAL_PROCESS_GROUP = None
+
+
+def get_num_devices():
+    gpu_list = os.getenv('CUDA_VISIBLE_DEVICES', None)
+    if gpu_list is not None:
+        return len(gpu_list.split(','))
+    else:
+        devices_list_info = os.popen("nvidia-smi -L")
+        devices_list_info = devices_list_info.read().strip().split("\n")
+        return len(devices_list_info)
+
+
+@contextmanager
+def wait_for_the_master(local_rank: int = None):
+    """
+    Make all processes waiting for the master to do some task.
+
+    Args:
+        local_rank (int): the rank of the current process. Default to None.
+            If None, it will use the rank of the current process.
+    """
+    if local_rank is None:
+        local_rank = get_local_rank()
+
+    if local_rank > 0:
+        dist.barrier()
+    yield
+    if local_rank == 0:
+        if not dist.is_available():
+            return
+        if not dist.is_initialized():
+            return
+        else:
+            dist.barrier()
+
+
+def synchronize():
+    """
+    Helper function to synchronize (barrier) among all processes when using distributed training
+    """
+    if not dist.is_available():
+        return
+    if not dist.is_initialized():
+        return
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return
+    dist.barrier()
+
+
+def get_world_size() -> int:
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank() -> int:
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def get_local_rank() -> int:
+    """
+    Returns:
+        The rank of the current process within the local (per-machine) process group.
+    """
+    if _LOCAL_PROCESS_GROUP is None:
+        return get_rank()
+
+    if not dist.is_available():
+        return 0
+    if not dist.is_initialized():
+        return 0
+    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+    """
+    Returns:
+        The size of the per-machine process group, i.e. the number of processes per machine.
+    """
+    if not dist.is_available():
+        return 1
+    if not dist.is_initialized():
+        return 1
+    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+    return get_rank() == 0
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+    """
+    Return a process group based on gloo backend, containing all the ranks
+    The result is cached.
+    """
+    if dist.get_backend() == "nccl":
+        return dist.new_group(backend="gloo")
+    else:
+        return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+    backend = dist.get_backend(group)
+    assert backend in ["gloo", "nccl"]
+    device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+    buffer = pickle.dumps(data)
+    if len(buffer) > 1024 ** 3:
+        logger.warning(
+            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+                get_rank(), len(buffer) / (1024 ** 3), device
+            )
+        )
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to(device=device)
+    return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+    """
+    Returns:
+        list[int]: size of the tensor, on each rank
+        Tensor: padded tensor that has the max size
+    """
+    world_size = dist.get_world_size(group=group)
+    assert (
+        world_size >= 1
+    ), "comm.gather/all_gather must be called from ranks within the given group!"
+    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+    size_list = [
+        torch.zeros([1], dtype=torch.int64, device=tensor.device)
+        for _ in range(world_size)
+    ]
+    dist.all_gather(size_list, local_size, group=group)
+    size_list = [int(size.item()) for size in size_list]
+
+    max_size = max(size_list)
+
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    if local_size != max_size:
+        padding = torch.zeros(
+            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
+        )
+        tensor = torch.cat((tensor, padding), dim=0)
+    return size_list, tensor
+
+
+def all_gather(data, group=None):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+    Args:
+        data: any picklable object
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group) == 1:
+        return [data]
+
+    tensor = _serialize_to_tensor(data, group)
+
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    tensor_list = [
+        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
+        for _ in size_list
+    ]
+    dist.all_gather(tensor_list, tensor, group=group)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def gather(data, dst=0, group=None):
+    """
+    Run gather on arbitrary picklable data (not necessarily tensors).
+
+    Args:
+        data: any picklable object
+        dst (int): destination rank
+        group: a torch process group. By default, will use a group which
+            contains all ranks on gloo backend.
+
+    Returns:
+        list[data]: on dst, a list of data gathered from each rank. Otherwise,
+            an empty list.
+    """
+    if get_world_size() == 1:
+        return [data]
+    if group is None:
+        group = _get_global_gloo_group()
+    if dist.get_world_size(group=group) == 1:
+        return [data]
+    rank = dist.get_rank(group=group)
+
+    tensor = _serialize_to_tensor(data, group)
+    size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+    # receiving Tensor from all ranks
+    if rank == dst:
+        max_size = max(size_list)
+        tensor_list = [
+            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
+            for _ in size_list
+        ]
+        dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+        data_list = []
+        for size, tensor in zip(size_list, tensor_list):
+            buffer = tensor.cpu().numpy().tobytes()[:size]
+            data_list.append(pickle.loads(buffer))
+        return data_list
+    else:
+        dist.gather(tensor, [], dst=dst, group=group)
+        return []
+
+
+def shared_random_seed():
+    """
+    Returns:
+        int: a random number that is the same across all workers.
+            If workers need a shared RNG, they can use this shared seed to
+            create one.
+    All workers must call this function, otherwise it will deadlock.
+    """
+    ints = np.random.randint(2 ** 31)
+    all_ints = all_gather(ints)
+    return all_ints[0]
+
+
+def time_synchronized():
+    """pytorch-accurate time"""
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    return time.time()
diff --git a/multimodal/YOLOX/yolox/utils/ema.py b/multimodal/YOLOX/yolox/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..73acbca6796d3cdd07397e657167acdbd5a57647
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/ema.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+import math
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+__all__ = ["ModelEMA", "is_parallel"]
+
+
+def is_parallel(model):
+    """check if model is in parallel mode."""
+    parallel_type = (
+        nn.parallel.DataParallel,
+        nn.parallel.DistributedDataParallel,
+    )
+    return isinstance(model, parallel_type)
+
+
+class ModelEMA:
+    """
+    Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    """
+
+    def __init__(self, model, decay=0.9999, updates=0):
+        """
+        Args:
+            model (nn.Module): model to apply EMA.
+            decay (float): ema decay reate.
+            updates (int): counter of EMA updates.
+        """
+        # Create EMA(FP32)
+        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
+        self.updates = updates
+        # decay exponential ramp (to help early epochs)
+        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def update(self, model):
+        # Update EMA parameters
+        with torch.no_grad():
+            self.updates += 1
+            d = self.decay(self.updates)
+
+            msd = (
+                model.module.state_dict() if is_parallel(model) else model.state_dict()
+            )  # model state_dict
+            for k, v in self.ema.state_dict().items():
+                if v.dtype.is_floating_point:
+                    v *= d
+                    v += (1.0 - d) * msd[k].detach()
diff --git a/multimodal/YOLOX/yolox/utils/logger.py b/multimodal/YOLOX/yolox/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..1045a7b47c579041b3cef5c9a408a210caa5e64f
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/logger.py
@@ -0,0 +1,440 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import inspect
+import os
+import sys
+from collections import defaultdict
+from loguru import logger
+
+import cv2
+import numpy as np
+
+import torch
+
+
+def get_caller_name(depth=0):
+    """
+    Args:
+        depth (int): Depth of caller conext, use 0 for caller depth.
+        Default value: 0.
+
+    Returns:
+        str: module name of the caller
+    """
+    # the following logic is a little bit faster than inspect.stack() logic
+    frame = inspect.currentframe().f_back
+    for _ in range(depth):
+        frame = frame.f_back
+
+    return frame.f_globals["__name__"]
+
+
+class StreamToLoguru:
+    """
+    stream object that redirects writes to a logger instance.
+    """
+
+    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
+        """
+        Args:
+            level(str): log level string of loguru. Default value: "INFO".
+            caller_names(tuple): caller names of redirected module.
+                Default value: (apex, pycocotools).
+        """
+        self.level = level
+        self.linebuf = ""
+        self.caller_names = caller_names
+
+    def write(self, buf):
+        full_name = get_caller_name(depth=1)
+        module_name = full_name.rsplit(".", maxsplit=-1)[0]
+        if module_name in self.caller_names:
+            for line in buf.rstrip().splitlines():
+                # use caller level log
+                logger.opt(depth=2).log(self.level, line.rstrip())
+        else:
+            sys.__stdout__.write(buf)
+
+    def flush(self):
+        # flush is related with CPR(cursor position report) in terminal
+        return sys.__stdout__.flush()
+
+    def isatty(self):
+        # when using colab, jax is installed by default and issue like
+        # https://github.com/Megvii-BaseDetection/YOLOX/issues/1437 might be raised
+        # due to missing attribute like`isatty`.
+        # For more details, checked the following link:
+        # https://github.com/google/jax/blob/10720258ea7fb5bde997dfa2f3f71135ab7a6733/jax/_src/pretty_printer.py#L54  # noqa
+        return sys.__stdout__.isatty()
+
+    def fileno(self):
+        # To solve the issue when using debug tools like pdb
+        return sys.__stdout__.fileno()
+
+
+def redirect_sys_output(log_level="INFO"):
+    redirect_logger = StreamToLoguru(log_level)
+    sys.stderr = redirect_logger
+    sys.stdout = redirect_logger
+
+
+def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
+    """setup logger for training and testing.
+    Args:
+        save_dir(str): location to save log file
+        distributed_rank(int): device rank when multi-gpu environment
+        filename (string): log save name.
+        mode(str): log file write mode, `append` or `override`. default is `a`.
+
+    Return:
+        logger instance.
+    """
+    loguru_format = (
+        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
+        "<level>{level: <8}</level> | "
+        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
+    )
+
+    logger.remove()
+    save_file = os.path.join(save_dir, filename)
+    if mode == "o" and os.path.exists(save_file):
+        os.remove(save_file)
+    # only keep logger in rank0 process
+    if distributed_rank == 0:
+        logger.add(
+            sys.stderr,
+            format=loguru_format,
+            level="INFO",
+            enqueue=True,
+        )
+        logger.add(save_file)
+
+    # redirect stdout/stderr to loguru
+    redirect_sys_output("INFO")
+
+
+class WandbLogger(object):
+    """
+    Log training runs, datasets, models, and predictions to Weights & Biases.
+    This logger sends information to W&B at wandb.ai.
+    By default, this information includes hyperparameters,
+    system configuration and metrics, model metrics,
+    and basic data metrics and analyses.
+
+    For more information, please refer to:
+    https://docs.wandb.ai/guides/track
+    https://docs.wandb.ai/guides/integrations/other/yolox
+    """
+    def __init__(self,
+                 project=None,
+                 name=None,
+                 id=None,
+                 entity=None,
+                 save_dir=None,
+                 config=None,
+                 val_dataset=None,
+                 num_eval_images=100,
+                 log_checkpoints=False,
+                 **kwargs):
+        """
+        Args:
+            project (str): wandb project name.
+            name (str): wandb run name.
+            id (str): wandb run id.
+            entity (str): wandb entity name.
+            save_dir (str): save directory.
+            config (dict): config dict.
+            val_dataset (Dataset): validation dataset.
+            num_eval_images (int): number of images from the validation set to log.
+            log_checkpoints (bool): log checkpoints
+            **kwargs: other kwargs.
+
+        Usage:
+            Any arguments for wandb.init can be provided on the command line using
+            the prefix `wandb-`.
+            Example
+            ```
+            python tools/train.py .... --logger wandb wandb-project <project-name> \
+                wandb-name <run-name> \
+                wandb-id <run-id> \
+                wandb-save_dir <save-dir> \
+                wandb-num_eval_imges <num-images> \
+                wandb-log_checkpoints <bool>
+            ```
+            The val_dataset argument is not open to the command line.
+        """
+        try:
+            import wandb
+            self.wandb = wandb
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "wandb is not installed."
+                "Please install wandb using pip install wandb"
+                )
+
+        from yolox.data.datasets import VOCDetection
+
+        self.project = project
+        self.name = name
+        self.id = id
+        self.save_dir = save_dir
+        self.config = config
+        self.kwargs = kwargs
+        self.entity = entity
+        self._run = None
+        self.val_artifact = None
+        if num_eval_images == -1:
+            self.num_log_images = len(val_dataset)
+        else:
+            self.num_log_images = min(num_eval_images, len(val_dataset))
+        self.log_checkpoints = (log_checkpoints == "True" or log_checkpoints == "true")
+        self._wandb_init = dict(
+            project=self.project,
+            name=self.name,
+            id=self.id,
+            entity=self.entity,
+            dir=self.save_dir,
+            resume="allow"
+        )
+        self._wandb_init.update(**kwargs)
+
+        _ = self.run
+
+        if self.config:
+            self.run.config.update(self.config)
+        self.run.define_metric("train/epoch")
+        self.run.define_metric("val/*", step_metric="train/epoch")
+        self.run.define_metric("train/step")
+        self.run.define_metric("train/*", step_metric="train/step")
+
+        self.voc_dataset = VOCDetection
+
+        if val_dataset and self.num_log_images != 0:
+            self.val_dataset = val_dataset
+            self.cats = val_dataset.cats
+            self.id_to_class = {
+                cls['id']: cls['name'] for cls in self.cats
+            }
+            self._log_validation_set(val_dataset)
+
+    @property
+    def run(self):
+        if self._run is None:
+            if self.wandb.run is not None:
+                logger.info(
+                    "There is a wandb run already in progress "
+                    "and newly created instances of `WandbLogger` will reuse"
+                    " this run. If this is not desired, call `wandb.finish()`"
+                    "before instantiating `WandbLogger`."
+                )
+                self._run = self.wandb.run
+            else:
+                self._run = self.wandb.init(**self._wandb_init)
+        return self._run
+
+    def _log_validation_set(self, val_dataset):
+        """
+        Log validation set to wandb.
+
+        Args:
+            val_dataset (Dataset): validation dataset.
+        """
+        if self.val_artifact is None:
+            self.val_artifact = self.wandb.Artifact(name="validation_images", type="dataset")
+            self.val_table = self.wandb.Table(columns=["id", "input"])
+
+            for i in range(self.num_log_images):
+                data_point = val_dataset[i]
+                img = data_point[0]
+                id = data_point[3]
+                img = np.transpose(img, (1, 2, 0))
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+                if isinstance(id, torch.Tensor):
+                    id = id.item()
+
+                self.val_table.add_data(
+                    id,
+                    self.wandb.Image(img)
+                )
+
+            self.val_artifact.add(self.val_table, "validation_images_table")
+            self.run.use_artifact(self.val_artifact)
+            self.val_artifact.wait()
+
+    def _convert_prediction_format(self, predictions):
+        image_wise_data = defaultdict(int)
+
+        for key, val in predictions.items():
+            img_id = key
+
+            try:
+                bboxes, cls, scores = val
+            except KeyError:
+                bboxes, cls, scores = val["bboxes"], val["categories"], val["scores"]
+
+            # These store information of actual bounding boxes i.e. the ones which are not None
+            act_box = []
+            act_scores = []
+            act_cls = []
+
+            if bboxes is not None:
+                for box, classes, score in zip(bboxes, cls, scores):
+                    if box is None or score is None or classes is None:
+                        continue
+                    act_box.append(box)
+                    act_scores.append(score)
+                    act_cls.append(classes)
+
+            image_wise_data.update({
+                int(img_id): {
+                    "bboxes": [box.numpy().tolist() for box in act_box],
+                    "scores": [score.numpy().item() for score in act_scores],
+                    "categories": [
+                        self.val_dataset.class_ids[int(act_cls[ind])]
+                        for ind in range(len(act_box))
+                    ],
+                }
+            })
+
+        return image_wise_data
+
+    def log_metrics(self, metrics, step=None):
+        """
+        Args:
+            metrics (dict): metrics dict.
+            step (int): step number.
+        """
+
+        for k, v in metrics.items():
+            if isinstance(v, torch.Tensor):
+                metrics[k] = v.item()
+
+        if step is not None:
+            metrics.update({"train/step": step})
+            self.run.log(metrics)
+        else:
+            self.run.log(metrics)
+
+    def log_images(self, predictions):
+        if len(predictions) == 0 or self.val_artifact is None or self.num_log_images == 0:
+            return
+
+        table_ref = self.val_artifact.get("validation_images_table")
+
+        columns = ["id", "predicted"]
+        for cls in self.cats:
+            columns.append(cls["name"])
+
+        if isinstance(self.val_dataset, self.voc_dataset):
+            predictions = self._convert_prediction_format(predictions)
+
+        result_table = self.wandb.Table(columns=columns)
+
+        for idx, val in table_ref.iterrows():
+
+            avg_scores = defaultdict(int)
+            num_occurrences = defaultdict(int)
+
+            id = val[0]
+            if isinstance(id, list):
+                id = id[0]
+
+            if id in predictions:
+                prediction = predictions[id]
+                boxes = []
+                for i in range(len(prediction["bboxes"])):
+                    bbox = prediction["bboxes"][i]
+                    x0 = bbox[0]
+                    y0 = bbox[1]
+                    x1 = bbox[2]
+                    y1 = bbox[3]
+                    box = {
+                        "position": {
+                            "minX": min(x0, x1),
+                            "minY": min(y0, y1),
+                            "maxX": max(x0, x1),
+                            "maxY": max(y0, y1)
+                        },
+                        "class_id": prediction["categories"][i],
+                        "domain": "pixel"
+                    }
+                    avg_scores[
+                        self.id_to_class[prediction["categories"][i]]
+                    ] += prediction["scores"][i]
+                    num_occurrences[self.id_to_class[prediction["categories"][i]]] += 1
+                    boxes.append(box)
+            else:
+                boxes = []
+            average_class_score = []
+            for cls in self.cats:
+                if cls["name"] not in num_occurrences:
+                    score = 0
+                else:
+                    score = avg_scores[cls["name"]] / num_occurrences[cls["name"]]
+                average_class_score.append(score)
+            result_table.add_data(
+                idx,
+                self.wandb.Image(val[1], boxes={
+                        "prediction": {
+                            "box_data": boxes,
+                            "class_labels": self.id_to_class
+                        }
+                    }
+                ),
+                *average_class_score
+            )
+
+        self.wandb.log({"val_results/result_table": result_table})
+
+    def save_checkpoint(self, save_dir, model_name, is_best, metadata=None):
+        """
+        Args:
+            save_dir (str): save directory.
+            model_name (str): model name.
+            is_best (bool): whether the model is the best model.
+            metadata (dict): metadata to save corresponding to the checkpoint.
+        """
+
+        if not self.log_checkpoints:
+            return
+
+        if "epoch" in metadata:
+            epoch = metadata["epoch"]
+        else:
+            epoch = None
+
+        filename = os.path.join(save_dir, model_name + "_ckpt.pth")
+        artifact = self.wandb.Artifact(
+            name=f"run_{self.run.id}_model",
+            type="model",
+            metadata=metadata
+        )
+        artifact.add_file(filename, name="model_ckpt.pth")
+
+        aliases = ["latest"]
+
+        if is_best:
+            aliases.append("best")
+
+        if epoch:
+            aliases.append(f"epoch-{epoch}")
+
+        self.run.log_artifact(artifact, aliases=aliases)
+
+    def finish(self):
+        self.run.finish()
+
+    @classmethod
+    def initialize_wandb_logger(cls, args, exp, val_dataset):
+        wandb_params = dict()
+        prefix = "wandb-"
+        for k, v in zip(args.opts[0::2], args.opts[1::2]):
+            if k.startswith("wandb-"):
+                try:
+                    wandb_params.update({k[len(prefix):]: int(v)})
+                except ValueError:
+                    wandb_params.update({k[len(prefix):]: v})
+
+        return cls(config=vars(exp), val_dataset=val_dataset, **wandb_params)
diff --git a/multimodal/YOLOX/yolox/utils/lr_scheduler.py b/multimodal/YOLOX/yolox/utils/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c00cf23281ac370957fccb062635b36dede8ea
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/lr_scheduler.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import math
+from functools import partial
+
+
+class LRScheduler:
+    def __init__(self, name, lr, iters_per_epoch, total_epochs, **kwargs):
+        """
+        Supported lr schedulers: [cos, warmcos, multistep]
+
+        Args:
+            lr (float): learning rate.
+            iters_per_epoch (int): number of iterations in one epoch.
+            total_epochs (int): number of epochs in training.
+            kwargs (dict):
+                - cos: None
+                - warmcos: [warmup_epochs, warmup_lr_start (default 1e-6)]
+                - multistep: [milestones (epochs), gamma (default 0.1)]
+        """
+
+        self.lr = lr
+        self.iters_per_epoch = iters_per_epoch
+        self.total_epochs = total_epochs
+        self.total_iters = iters_per_epoch * total_epochs
+
+        self.__dict__.update(kwargs)
+
+        self.lr_func = self._get_lr_func(name)
+
+    def update_lr(self, iters):
+        return self.lr_func(iters)
+
+    def _get_lr_func(self, name):
+        if name == "cos":  # cosine lr schedule
+            lr_func = partial(cos_lr, self.lr, self.total_iters)
+        elif name == "warmcos":
+            warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
+            warmup_lr_start = getattr(self, "warmup_lr_start", 1e-6)
+            lr_func = partial(
+                warm_cos_lr,
+                self.lr,
+                self.total_iters,
+                warmup_total_iters,
+                warmup_lr_start,
+            )
+        elif name == "yoloxwarmcos":
+            warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
+            no_aug_iters = self.iters_per_epoch * self.no_aug_epochs
+            warmup_lr_start = getattr(self, "warmup_lr_start", 0)
+            min_lr_ratio = getattr(self, "min_lr_ratio", 0.2)
+            lr_func = partial(
+                yolox_warm_cos_lr,
+                self.lr,
+                min_lr_ratio,
+                self.total_iters,
+                warmup_total_iters,
+                warmup_lr_start,
+                no_aug_iters,
+            )
+        elif name == "yoloxsemiwarmcos":
+            warmup_lr_start = getattr(self, "warmup_lr_start", 0)
+            min_lr_ratio = getattr(self, "min_lr_ratio", 0.2)
+            warmup_total_iters = self.iters_per_epoch * self.warmup_epochs
+            no_aug_iters = self.iters_per_epoch * self.no_aug_epochs
+            normal_iters = self.iters_per_epoch * self.semi_epoch
+            semi_iters = self.iters_per_epoch_semi * (
+                self.total_epochs - self.semi_epoch - self.no_aug_epochs
+            )
+            lr_func = partial(
+                yolox_semi_warm_cos_lr,
+                self.lr,
+                min_lr_ratio,
+                warmup_lr_start,
+                self.total_iters,
+                normal_iters,
+                no_aug_iters,
+                warmup_total_iters,
+                semi_iters,
+                self.iters_per_epoch,
+                self.iters_per_epoch_semi,
+            )
+        elif name == "multistep":  # stepwise lr schedule
+            milestones = [
+                int(self.total_iters * milestone / self.total_epochs)
+                for milestone in self.milestones
+            ]
+            gamma = getattr(self, "gamma", 0.1)
+            lr_func = partial(multistep_lr, self.lr, milestones, gamma)
+        else:
+            raise ValueError("Scheduler version {} not supported.".format(name))
+        return lr_func
+
+
+def cos_lr(lr, total_iters, iters):
+    """Cosine learning rate"""
+    lr *= 0.5 * (1.0 + math.cos(math.pi * iters / total_iters))
+    return lr
+
+
+def warm_cos_lr(lr, total_iters, warmup_total_iters, warmup_lr_start, iters):
+    """Cosine learning rate with warm up."""
+    if iters <= warmup_total_iters:
+        lr = (lr - warmup_lr_start) * iters / float(
+            warmup_total_iters
+        ) + warmup_lr_start
+    else:
+        lr *= 0.5 * (
+            1.0
+            + math.cos(
+                math.pi
+                * (iters - warmup_total_iters)
+                / (total_iters - warmup_total_iters)
+            )
+        )
+    return lr
+
+
+def yolox_warm_cos_lr(
+    lr,
+    min_lr_ratio,
+    total_iters,
+    warmup_total_iters,
+    warmup_lr_start,
+    no_aug_iter,
+    iters,
+):
+    """Cosine learning rate with warm up."""
+    min_lr = lr * min_lr_ratio
+    if iters <= warmup_total_iters:
+        # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
+        lr = (lr - warmup_lr_start) * pow(
+            iters / float(warmup_total_iters), 2
+        ) + warmup_lr_start
+    elif iters >= total_iters - no_aug_iter:
+        lr = min_lr
+    else:
+        lr = min_lr + 0.5 * (lr - min_lr) * (
+            1.0
+            + math.cos(
+                math.pi
+                * (iters - warmup_total_iters)
+                / (total_iters - warmup_total_iters - no_aug_iter)
+            )
+        )
+    return lr
+
+
+def yolox_semi_warm_cos_lr(
+    lr,
+    min_lr_ratio,
+    warmup_lr_start,
+    total_iters,
+    normal_iters,
+    no_aug_iters,
+    warmup_total_iters,
+    semi_iters,
+    iters_per_epoch,
+    iters_per_epoch_semi,
+    iters,
+):
+    """Cosine learning rate with warm up."""
+    min_lr = lr * min_lr_ratio
+    if iters <= warmup_total_iters:
+        # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
+        lr = (lr - warmup_lr_start) * pow(
+            iters / float(warmup_total_iters), 2
+        ) + warmup_lr_start
+    elif iters >= normal_iters + semi_iters:
+        lr = min_lr
+    elif iters <= normal_iters:
+        lr = min_lr + 0.5 * (lr - min_lr) * (
+            1.0
+            + math.cos(
+                math.pi
+                * (iters - warmup_total_iters)
+                / (total_iters - warmup_total_iters - no_aug_iters)
+            )
+        )
+    else:
+        lr = min_lr + 0.5 * (lr - min_lr) * (
+            1.0
+            + math.cos(
+                math.pi
+                * (
+                    normal_iters
+                    - warmup_total_iters
+                    + (iters - normal_iters)
+                    * iters_per_epoch
+                    * 1.0
+                    / iters_per_epoch_semi
+                )
+                / (total_iters - warmup_total_iters - no_aug_iters)
+            )
+        )
+    return lr
+
+
+def multistep_lr(lr, milestones, gamma, iters):
+    """MultiStep learning rate"""
+    for milestone in milestones:
+        lr *= gamma if iters >= milestone else 1.0
+    return lr
diff --git a/multimodal/YOLOX/yolox/utils/metric.py b/multimodal/YOLOX/yolox/utils/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..506b58281896ade91184e5a34d677f1b185a31fe
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/metric.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+import functools
+import os
+import time
+from collections import defaultdict, deque
+import psutil
+
+import numpy as np
+
+import torch
+
+__all__ = [
+    "AverageMeter",
+    "MeterBuffer",
+    "get_total_and_free_memory_in_Mb",
+    "occupy_mem",
+    "gpu_mem_usage",
+    "mem_usage"
+]
+
+
+def get_total_and_free_memory_in_Mb(cuda_device):
+    devices_info_str = os.popen(
+        "nvidia-smi --query-gpu=memory.total,memory.used --format=csv,nounits,noheader"
+    )
+    devices_info = devices_info_str.read().strip().split("\n")
+    if "CUDA_VISIBLE_DEVICES" in os.environ:
+        visible_devices = os.environ["CUDA_VISIBLE_DEVICES"].split(',')
+        cuda_device = int(visible_devices[cuda_device])
+    total, used = devices_info[int(cuda_device)].split(",")
+    return int(total), int(used)
+
+
+def occupy_mem(cuda_device, mem_ratio=0.9):
+    """
+    pre-allocate gpu memory for training to avoid memory Fragmentation.
+    """
+    total, used = get_total_and_free_memory_in_Mb(cuda_device)
+    max_mem = int(total * mem_ratio)
+    block_mem = max_mem - used
+    x = torch.cuda.FloatTensor(256, 1024, block_mem)
+    del x
+    time.sleep(5)
+
+
+def gpu_mem_usage():
+    """
+    Compute the GPU memory usage for the current device (MB).
+    """
+    mem_usage_bytes = torch.cuda.max_memory_allocated()
+    return mem_usage_bytes / (1024 * 1024)
+
+
+def mem_usage():
+    """
+    Compute the memory usage for the current machine (GB).
+    """
+    gb = 1 << 30
+    mem = psutil.virtual_memory()
+    return mem.used / gb
+
+
+class AverageMeter:
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=50):
+        self._deque = deque(maxlen=window_size)
+        self._total = 0.0
+        self._count = 0
+
+    def update(self, value):
+        self._deque.append(value)
+        self._count += 1
+        self._total += value
+
+    @property
+    def median(self):
+        d = np.array(list(self._deque))
+        return np.median(d)
+
+    @property
+    def avg(self):
+        # if deque is empty, nan will be returned.
+        d = np.array(list(self._deque))
+        return d.mean()
+
+    @property
+    def global_avg(self):
+        return self._total / max(self._count, 1e-5)
+
+    @property
+    def latest(self):
+        return self._deque[-1] if len(self._deque) > 0 else None
+
+    @property
+    def total(self):
+        return self._total
+
+    def reset(self):
+        self._deque.clear()
+        self._total = 0.0
+        self._count = 0
+
+    def clear(self):
+        self._deque.clear()
+
+
+class MeterBuffer(defaultdict):
+    """Computes and stores the average and current value"""
+
+    def __init__(self, window_size=20):
+        factory = functools.partial(AverageMeter, window_size=window_size)
+        super().__init__(factory)
+
+    def reset(self):
+        for v in self.values():
+            v.reset()
+
+    def get_filtered_meter(self, filter_key="time"):
+        return {k: v for k, v in self.items() if filter_key in k}
+
+    def update(self, values=None, **kwargs):
+        if values is None:
+            values = {}
+        values.update(kwargs)
+        for k, v in values.items():
+            if isinstance(v, torch.Tensor):
+                v = v.detach()
+            self[k].update(v)
+
+    def clear_meters(self):
+        for v in self.values():
+            v.clear()
diff --git a/multimodal/YOLOX/yolox/utils/model_utils.py b/multimodal/YOLOX/yolox/utils/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bc2d1ff7a314e143ec3424a0afefc73b7b5b137
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/model_utils.py
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import contextlib
+from copy import deepcopy
+from typing import Sequence
+
+import torch
+import torch.nn as nn
+
+__all__ = [
+    "fuse_conv_and_bn",
+    "fuse_model",
+    "get_model_info",
+    "replace_module",
+    "freeze_module",
+    "adjust_status",
+]
+
+
+def get_model_info(model: nn.Module, tsize: Sequence[int]) -> str:
+    from thop import profile
+
+    stride = 64
+    img = torch.zeros((1, 3, stride, stride), device=next(model.parameters()).device)
+    flops, params = profile(deepcopy(model), inputs=(img,), verbose=False)
+    params /= 1e6
+    flops /= 1e9
+    flops *= tsize[0] * tsize[1] / stride / stride * 2  # Gflops
+    info = "Params: {:.2f}M, Gflops: {:.2f}".format(params, flops)
+    return info
+
+
+def fuse_conv_and_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
+    """
+    Fuse convolution and batchnorm layers.
+    check more info on https://tehnokv.com/posts/fusing-batchnorm-and-conv/
+
+    Args:
+        conv (nn.Conv2d): convolution to fuse.
+        bn (nn.BatchNorm2d): batchnorm to fuse.
+
+    Returns:
+        nn.Conv2d: fused convolution behaves the same as the input conv and bn.
+    """
+    fusedconv = (
+        nn.Conv2d(
+            conv.in_channels,
+            conv.out_channels,
+            kernel_size=conv.kernel_size,
+            stride=conv.stride,
+            padding=conv.padding,
+            groups=conv.groups,
+            bias=True,
+        )
+        .requires_grad_(False)
+        .to(conv.weight.device)
+    )
+
+    # prepare filters
+    w_conv = conv.weight.clone().view(conv.out_channels, -1)
+    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
+
+    # prepare spatial bias
+    b_conv = (
+        torch.zeros(conv.weight.size(0), device=conv.weight.device)
+        if conv.bias is None
+        else conv.bias
+    )
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
+        torch.sqrt(bn.running_var + bn.eps)
+    )
+    fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+    return fusedconv
+
+
+def fuse_model(model: nn.Module) -> nn.Module:
+    """fuse conv and bn in model
+
+    Args:
+        model (nn.Module): model to fuse
+
+    Returns:
+        nn.Module: fused model
+    """
+    from yolox.models.network_blocks import BaseConv
+
+    for m in model.modules():
+        if type(m) is BaseConv and hasattr(m, "bn"):
+            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+            delattr(m, "bn")  # remove batchnorm
+            m.forward = m.fuseforward  # update forward
+    return model
+
+
+def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
+    """
+    Replace given type in module to a new type. mostly used in deploy.
+
+    Args:
+        module (nn.Module): model to apply replace operation.
+        replaced_module_type (Type): module type to be replaced.
+        new_module_type (Type)
+        replace_func (function): python function to describe replace logic. Defalut value None.
+
+    Returns:
+        model (nn.Module): module that already been replaced.
+    """
+
+    def default_replace_func(replaced_module_type, new_module_type):
+        return new_module_type()
+
+    if replace_func is None:
+        replace_func = default_replace_func
+
+    model = module
+    if isinstance(module, replaced_module_type):
+        model = replace_func(replaced_module_type, new_module_type)
+    else:  # recurrsively replace
+        for name, child in module.named_children():
+            new_child = replace_module(child, replaced_module_type, new_module_type)
+            if new_child is not child:  # child is already replaced
+                model.add_module(name, new_child)
+
+    return model
+
+
+def freeze_module(module: nn.Module, name=None) -> nn.Module:
+    """freeze module inplace
+
+    Args:
+        module (nn.Module): module to freeze.
+        name (str, optional): name to freeze. If not given, freeze the whole module.
+            Note that fuzzy match is not supported. Defaults to None.
+
+    Examples:
+        freeze the backbone of model
+        >>> freeze_moudle(model.backbone)
+
+        or freeze the backbone of model by name
+        >>> freeze_moudle(model, name="backbone")
+    """
+    for param_name, parameter in module.named_parameters():
+        if name is None or name in param_name:
+            parameter.requires_grad = False
+
+    # ensure module like BN and dropout are freezed
+    for module_name, sub_module in module.named_modules():
+        # actually there are no needs to call eval for every single sub_module
+        if name is None or name in module_name:
+            sub_module.eval()
+
+    return module
+
+
+@contextlib.contextmanager
+def adjust_status(module: nn.Module, training: bool = False) -> nn.Module:
+    """Adjust module to training/eval mode temporarily.
+
+    Args:
+        module (nn.Module): module to adjust status.
+        training (bool): training mode to set. True for train mode, False fro eval mode.
+
+    Examples:
+        >>> with adjust_status(model, training=False):
+        ...     model(data)
+    """
+    status = {}
+
+    def backup_status(module):
+        for m in module.modules():
+            # save prev status to dict
+            status[m] = m.training
+            m.training = training
+
+    def recover_status(module):
+        for m in module.modules():
+            # recover prev status from dict
+            m.training = status.pop(m)
+
+    backup_status(module)
+    yield module
+    recover_status(module)
diff --git a/multimodal/YOLOX/yolox/utils/setup_env.py b/multimodal/YOLOX/yolox/utils/setup_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..45289f3245f09e48395ad419d17efffe6846b05c
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/setup_env.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import os
+import subprocess
+from loguru import logger
+
+import cv2
+
+from .dist import get_world_size, is_main_process
+
+__all__ = ["configure_nccl", "configure_module", "configure_omp"]
+
+
+def configure_nccl():
+    """Configure multi-machine environment variables of NCCL."""
+    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
+    os.environ["NCCL_IB_HCA"] = subprocess.getoutput(
+        "pushd /sys/class/infiniband/ > /dev/null; for i in mlx5_*; "
+        "do cat $i/ports/1/gid_attrs/types/* 2>/dev/null "
+        "| grep v >/dev/null && echo $i ; done; popd > /dev/null"
+    )
+    os.environ["NCCL_IB_GID_INDEX"] = "3"
+    os.environ["NCCL_IB_TC"] = "106"
+
+
+def configure_omp(num_threads=1):
+    """
+    If OMP_NUM_THREADS is not configured and world_size is greater than 1,
+    Configure OMP_NUM_THREADS environment variables of NCCL to `num_thread`.
+
+    Args:
+        num_threads (int): value of `OMP_NUM_THREADS` to set.
+    """
+    # We set OMP_NUM_THREADS=1 by default, which achieves the best speed on our machines
+    # feel free to change it for better performance.
+    if "OMP_NUM_THREADS" not in os.environ and get_world_size() > 1:
+        os.environ["OMP_NUM_THREADS"] = str(num_threads)
+        if is_main_process():
+            logger.info(
+                "\n***************************************************************\n"
+                "We set `OMP_NUM_THREADS` for each process to {} to speed up.\n"
+                "please further tune the variable for optimal performance.\n"
+                "***************************************************************".format(
+                    os.environ["OMP_NUM_THREADS"]
+                )
+            )
+
+
+def configure_module(ulimit_value=8192):
+    """
+    Configure pytorch module environment. setting of ulimit and cv2 will be set.
+
+    Args:
+        ulimit_value(int): default open file number on linux. Default value: 8192.
+    """
+    # system setting
+    try:
+        import resource
+
+        rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+        resource.setrlimit(resource.RLIMIT_NOFILE, (ulimit_value, rlimit[1]))
+    except Exception:
+        # Exception might be raised in Windows OS or rlimit reaches max limit number.
+        # However, set rlimit value might not be necessary.
+        pass
+
+    # cv2
+    # multiprocess might be harmful on performance of torch dataloader
+    os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
+    try:
+        cv2.setNumThreads(0)
+        cv2.ocl.setUseOpenCL(False)
+    except Exception:
+        # cv2 version mismatch might rasie exceptions.
+        pass
diff --git a/multimodal/YOLOX/yolox/utils/visualize.py b/multimodal/YOLOX/yolox/utils/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e714a3ee73699141fb4cd8d131d541a6e6625ed6
--- /dev/null
+++ b/multimodal/YOLOX/yolox/utils/visualize.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import cv2
+import numpy as np
+
+__all__ = ["vis"]
+
+
+def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
+
+    for i in range(len(boxes)):
+        box = boxes[i]
+        cls_id = int(cls_ids[i])
+        score = scores[i]
+        if score < conf:
+            continue
+        x0 = int(box[0])
+        y0 = int(box[1])
+        x1 = int(box[2])
+        y1 = int(box[3])
+
+        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
+        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
+        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
+        font = cv2.FONT_HERSHEY_SIMPLEX
+
+        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
+        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)
+
+        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
+        cv2.rectangle(
+            img,
+            (x0, y0 + 1),
+            (x0 + txt_size[0] + 1, y0 + int(1.5*txt_size[1])),
+            txt_bk_color,
+            -1
+        )
+        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)
+
+    return img
+
+
+_COLORS = np.array(
+    [
+        0.000, 0.447, 0.741,
+        0.850, 0.325, 0.098,
+        0.929, 0.694, 0.125,
+        0.494, 0.184, 0.556,
+        0.466, 0.674, 0.188,
+        0.301, 0.745, 0.933,
+        0.635, 0.078, 0.184,
+        0.300, 0.300, 0.300,
+        0.600, 0.600, 0.600,
+        1.000, 0.000, 0.000,
+        1.000, 0.500, 0.000,
+        0.749, 0.749, 0.000,
+        0.000, 1.000, 0.000,
+        0.000, 0.000, 1.000,
+        0.667, 0.000, 1.000,
+        0.333, 0.333, 0.000,
+        0.333, 0.667, 0.000,
+        0.333, 1.000, 0.000,
+        0.667, 0.333, 0.000,
+        0.667, 0.667, 0.000,
+        0.667, 1.000, 0.000,
+        1.000, 0.333, 0.000,
+        1.000, 0.667, 0.000,
+        1.000, 1.000, 0.000,
+        0.000, 0.333, 0.500,
+        0.000, 0.667, 0.500,
+        0.000, 1.000, 0.500,
+        0.333, 0.000, 0.500,
+        0.333, 0.333, 0.500,
+        0.333, 0.667, 0.500,
+        0.333, 1.000, 0.500,
+        0.667, 0.000, 0.500,
+        0.667, 0.333, 0.500,
+        0.667, 0.667, 0.500,
+        0.667, 1.000, 0.500,
+        1.000, 0.000, 0.500,
+        1.000, 0.333, 0.500,
+        1.000, 0.667, 0.500,
+        1.000, 1.000, 0.500,
+        0.000, 0.333, 1.000,
+        0.000, 0.667, 1.000,
+        0.000, 1.000, 1.000,
+        0.333, 0.000, 1.000,
+        0.333, 0.333, 1.000,
+        0.333, 0.667, 1.000,
+        0.333, 1.000, 1.000,
+        0.667, 0.000, 1.000,
+        0.667, 0.333, 1.000,
+        0.667, 0.667, 1.000,
+        0.667, 1.000, 1.000,
+        1.000, 0.000, 1.000,
+        1.000, 0.333, 1.000,
+        1.000, 0.667, 1.000,
+        0.333, 0.000, 0.000,
+        0.500, 0.000, 0.000,
+        0.667, 0.000, 0.000,
+        0.833, 0.000, 0.000,
+        1.000, 0.000, 0.000,
+        0.000, 0.167, 0.000,
+        0.000, 0.333, 0.000,
+        0.000, 0.500, 0.000,
+        0.000, 0.667, 0.000,
+        0.000, 0.833, 0.000,
+        0.000, 1.000, 0.000,
+        0.000, 0.000, 0.167,
+        0.000, 0.000, 0.333,
+        0.000, 0.000, 0.500,
+        0.000, 0.000, 0.667,
+        0.000, 0.000, 0.833,
+        0.000, 0.000, 1.000,
+        0.000, 0.000, 0.000,
+        0.143, 0.143, 0.143,
+        0.286, 0.286, 0.286,
+        0.429, 0.429, 0.429,
+        0.571, 0.571, 0.571,
+        0.714, 0.714, 0.714,
+        0.857, 0.857, 0.857,
+        0.000, 0.447, 0.741,
+        0.314, 0.717, 0.741,
+        0.50, 0.5, 0
+    ]
+).astype(np.float32).reshape(-1, 3)
diff --git a/multimodal/batch_submit.sh b/multimodal/batch_submit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8f40ac45549e2a83c82cd14a57f4822a3790e8c9
--- /dev/null
+++ b/multimodal/batch_submit.sh
@@ -0,0 +1,488 @@
+# sbatch -J label0 submit_labeling.sh 0 48
+# sleep 30
+sbatch -J label1 submit_labeling.sh 48 96
+sleep 30
+sbatch -J label2 submit_labeling.sh 96 144
+sleep 30
+sbatch -J label3 submit_labeling.sh 144 192
+sleep 30
+sbatch -J label4 submit_labeling.sh 192 240
+sleep 30
+sbatch -J label5 submit_labeling.sh 240 288
+sleep 30
+sbatch -J label6 submit_labeling.sh 288 336
+sleep 30
+sbatch -J label7 submit_labeling.sh 336 384
+sleep 30
+sbatch -J label8 submit_labeling.sh 384 432
+sleep 30
+sbatch -J label9 submit_labeling.sh 432 480
+sleep 30
+sbatch -J label10 submit_labeling.sh 480 528
+sleep 30
+sbatch -J label11 submit_labeling.sh 528 576
+sleep 30
+sbatch -J label12 submit_labeling.sh 576 624
+sleep 30
+sbatch -J label13 submit_labeling.sh 624 672
+sleep 30
+sbatch -J label14 submit_labeling.sh 672 720
+sleep 30
+sbatch -J label15 submit_labeling.sh 720 768
+sleep 30
+sbatch -J label16 submit_labeling.sh 768 816
+sleep 30
+sbatch -J label17 submit_labeling.sh 816 864
+sleep 30
+sbatch -J label18 submit_labeling.sh 864 912
+sleep 30
+sbatch -J label19 submit_labeling.sh 912 960
+sleep 30
+sbatch -J label20 submit_labeling.sh 960 1008
+sleep 30
+sbatch -J label21 submit_labeling.sh 1008 1056
+sleep 30
+sbatch -J label22 submit_labeling.sh 1056 1104
+sleep 30
+sbatch -J label23 submit_labeling.sh 1104 1152
+sleep 30
+sbatch -J label24 submit_labeling.sh 1152 1200
+sleep 30
+sbatch -J label25 submit_labeling.sh 1200 1248
+sleep 30
+sbatch -J label26 submit_labeling.sh 1248 1296
+sleep 30
+sbatch -J label27 submit_labeling.sh 1296 1344
+sleep 30
+sbatch -J label28 submit_labeling.sh 1344 1392
+sleep 30
+sbatch -J label29 submit_labeling.sh 1392 1440
+sleep 30
+sbatch -J label30 submit_labeling.sh 1440 1488
+sleep 30
+sbatch -J label31 submit_labeling.sh 1488 1536
+sleep 30
+sbatch -J label32 submit_labeling.sh 1536 1584
+sleep 30
+sbatch -J label33 submit_labeling.sh 1584 1632
+sleep 30
+sbatch -J label34 submit_labeling.sh 1632 1680
+sleep 30
+sbatch -J label35 submit_labeling.sh 1680 1728
+sleep 30
+sbatch -J label36 submit_labeling.sh 1728 1776
+sleep 30
+sbatch -J label37 submit_labeling.sh 1776 1824
+sleep 30
+sbatch -J label38 submit_labeling.sh 1824 1872
+sleep 30
+sbatch -J label39 submit_labeling.sh 1872 1920
+sleep 30
+sbatch -J label40 submit_labeling.sh 1920 1968
+sleep 30
+sbatch -J label41 submit_labeling.sh 1968 2016
+sleep 30
+sbatch -J label42 submit_labeling.sh 2016 2064
+sleep 30
+sbatch -J label43 submit_labeling.sh 2064 2112
+sleep 30
+sbatch -J label44 submit_labeling.sh 2112 2160
+sleep 30
+sbatch -J label45 submit_labeling.sh 2160 2208
+sleep 30
+sbatch -J label46 submit_labeling.sh 2208 2256
+sleep 30
+sbatch -J label47 submit_labeling.sh 2256 2304
+sleep 30
+sbatch -J label48 submit_labeling.sh 2304 2352
+sleep 30
+sbatch -J label49 submit_labeling.sh 2352 2400
+sleep 30
+sbatch -J label50 submit_labeling.sh 2400 2448
+sleep 30
+sbatch -J label51 submit_labeling.sh 2448 2496
+sleep 30
+sbatch -J label52 submit_labeling.sh 2496 2544
+sleep 30
+sbatch -J label53 submit_labeling.sh 2544 2592
+sleep 30
+sbatch -J label54 submit_labeling.sh 2592 2640
+sleep 30
+sbatch -J label55 submit_labeling.sh 2640 2688
+sleep 30
+sbatch -J label56 submit_labeling.sh 2688 2736
+sleep 30
+sbatch -J label57 submit_labeling.sh 2736 2784
+sleep 30
+sbatch -J label58 submit_labeling.sh 2784 2832
+sleep 30
+sbatch -J label59 submit_labeling.sh 2832 2880
+sleep 30
+sbatch -J label60 submit_labeling.sh 2880 2928
+sleep 30
+sbatch -J label61 submit_labeling.sh 2928 2976
+sleep 30
+sbatch -J label62 submit_labeling.sh 2976 3024
+sleep 30
+sbatch -J label63 submit_labeling.sh 3024 3072
+sleep 30
+sbatch -J label64 submit_labeling.sh 3072 3120
+sleep 30
+sbatch -J label65 submit_labeling.sh 3120 3168
+sleep 30
+sbatch -J label66 submit_labeling.sh 3168 3216
+sleep 30
+sbatch -J label67 submit_labeling.sh 3216 3264
+sleep 30
+sbatch -J label68 submit_labeling.sh 3264 3312
+sleep 30
+sbatch -J label69 submit_labeling.sh 3312 3360
+sleep 30
+sbatch -J label70 submit_labeling.sh 3360 3408
+sleep 30
+sbatch -J label71 submit_labeling.sh 3408 3456
+sleep 30
+sbatch -J label72 submit_labeling.sh 3456 3504
+sleep 30
+sbatch -J label73 submit_labeling.sh 3504 3552
+sleep 30
+sbatch -J label74 submit_labeling.sh 3552 3600
+sleep 30
+sbatch -J label75 submit_labeling.sh 3600 3648
+sleep 30
+sbatch -J label76 submit_labeling.sh 3648 3696
+sleep 30
+sbatch -J label77 submit_labeling.sh 3696 3744
+sleep 30
+sbatch -J label78 submit_labeling.sh 3744 3792
+sleep 30
+sbatch -J label79 submit_labeling.sh 3792 3840
+sleep 30
+sbatch -J label80 submit_labeling.sh 3840 3888
+sleep 30
+sbatch -J label81 submit_labeling.sh 3888 3936
+sleep 30
+sbatch -J label82 submit_labeling.sh 3936 3984
+sleep 30
+sbatch -J label83 submit_labeling.sh 3984 4032
+sleep 30
+sbatch -J label84 submit_labeling.sh 4032 4080
+sleep 30
+sbatch -J label85 submit_labeling.sh 4080 4128
+sleep 30
+sbatch -J label86 submit_labeling.sh 4128 4176
+sleep 30
+sbatch -J label87 submit_labeling.sh 4176 4224
+sleep 30
+sbatch -J label88 submit_labeling.sh 4224 4272
+sleep 30
+sbatch -J label89 submit_labeling.sh 4272 4320
+sleep 30
+sbatch -J label90 submit_labeling.sh 4320 4368
+sleep 30
+sbatch -J label91 submit_labeling.sh 4368 4416
+sleep 30
+sbatch -J label92 submit_labeling.sh 4416 4464
+sleep 30
+sbatch -J label93 submit_labeling.sh 4464 4512
+sleep 30
+sbatch -J label94 submit_labeling.sh 4512 4560
+sleep 30
+sbatch -J label95 submit_labeling.sh 4560 4608
+sleep 30
+sbatch -J label96 submit_labeling.sh 4608 4656
+sleep 30
+sbatch -J label97 submit_labeling.sh 4656 4704
+sleep 30
+sbatch -J label98 submit_labeling.sh 4704 4752
+sleep 30
+sbatch -J label99 submit_labeling.sh 4752 4800
+sleep 30
+sbatch -J label100 submit_labeling.sh 4800 4848
+sleep 30
+sbatch -J label101 submit_labeling.sh 4848 4896
+sleep 30
+sbatch -J label102 submit_labeling.sh 4896 4944
+sleep 30
+sbatch -J label103 submit_labeling.sh 4944 4992
+sleep 30
+sbatch -J label104 submit_labeling.sh 4992 5040
+sleep 30
+sbatch -J label105 submit_labeling.sh 5040 5088
+sleep 30
+sbatch -J label106 submit_labeling.sh 5088 5136
+sleep 30
+sbatch -J label107 submit_labeling.sh 5136 5184
+sleep 30
+sbatch -J label108 submit_labeling.sh 5184 5232
+sleep 30
+sbatch -J label109 submit_labeling.sh 5232 5280
+sleep 30
+sbatch -J label110 submit_labeling.sh 5280 5328
+sleep 30
+sbatch -J label111 submit_labeling.sh 5328 5376
+sleep 30
+sbatch -J label112 submit_labeling.sh 5376 5424
+sleep 30
+sbatch -J label113 submit_labeling.sh 5424 5472
+sleep 30
+sbatch -J label114 submit_labeling.sh 5472 5520
+sleep 30
+sbatch -J label115 submit_labeling.sh 5520 5568
+sleep 30
+sbatch -J label116 submit_labeling.sh 5568 5616
+sleep 30
+sbatch -J label117 submit_labeling.sh 5616 5664
+sleep 30
+sbatch -J label118 submit_labeling.sh 5664 5712
+sleep 30
+sbatch -J label119 submit_labeling.sh 5712 5760
+sleep 30
+sbatch -J label120 submit_labeling.sh 5760 5808
+sleep 30
+sbatch -J label121 submit_labeling.sh 5808 5856
+sleep 30
+sbatch -J label122 submit_labeling.sh 5856 5904
+sleep 30
+sbatch -J label123 submit_labeling.sh 5904 5952
+sleep 30
+sbatch -J label124 submit_labeling.sh 5952 6000
+sleep 30
+sbatch -J label125 submit_labeling.sh 6000 6048
+sleep 30
+sbatch -J label126 submit_labeling.sh 6048 6096
+sleep 30
+sbatch -J label127 submit_labeling.sh 6096 6144
+sleep 30
+sbatch -J label128 submit_labeling.sh 6144 6192
+sleep 30
+sbatch -J label129 submit_labeling.sh 6192 6240
+sleep 30
+sbatch -J label130 submit_labeling.sh 6240 6288
+sleep 30
+sbatch -J label131 submit_labeling.sh 6288 6336
+sleep 30
+sbatch -J label132 submit_labeling.sh 6336 6384
+sleep 30
+sbatch -J label133 submit_labeling.sh 6384 6432
+sleep 30
+sbatch -J label134 submit_labeling.sh 6432 6480
+sleep 30
+sbatch -J label135 submit_labeling.sh 6480 6528
+sleep 30
+sbatch -J label136 submit_labeling.sh 6528 6576
+sleep 30
+sbatch -J label137 submit_labeling.sh 6576 6624
+sleep 30
+sbatch -J label138 submit_labeling.sh 6624 6672
+sleep 30
+sbatch -J label139 submit_labeling.sh 6672 6720
+sleep 30
+sbatch -J label140 submit_labeling.sh 6720 6768
+sleep 30
+sbatch -J label141 submit_labeling.sh 6768 6816
+sleep 30
+sbatch -J label142 submit_labeling.sh 6816 6864
+sleep 30
+sbatch -J label143 submit_labeling.sh 6864 6912
+sleep 30
+sbatch -J label144 submit_labeling.sh 6912 6960
+sleep 30
+sbatch -J label145 submit_labeling.sh 6960 7008
+sleep 30
+sbatch -J label146 submit_labeling.sh 7008 7056
+sleep 30
+sbatch -J label147 submit_labeling.sh 7056 7104
+sleep 30
+sbatch -J label148 submit_labeling.sh 7104 7152
+sleep 30
+sbatch -J label149 submit_labeling.sh 7152 7200
+sleep 30
+sbatch -J label150 submit_labeling.sh 7200 7248
+sleep 30
+sbatch -J label151 submit_labeling.sh 7248 7296
+sleep 30
+sbatch -J label152 submit_labeling.sh 7296 7344
+sleep 30
+sbatch -J label153 submit_labeling.sh 7344 7392
+sleep 30
+sbatch -J label154 submit_labeling.sh 7392 7440
+sleep 30
+sbatch -J label155 submit_labeling.sh 7440 7488
+sleep 30
+sbatch -J label156 submit_labeling.sh 7488 7536
+sleep 30
+sbatch -J label157 submit_labeling.sh 7536 7584
+sleep 30
+sbatch -J label158 submit_labeling.sh 7584 7632
+sleep 30
+sbatch -J label159 submit_labeling.sh 7632 7680
+sleep 30
+sbatch -J label160 submit_labeling.sh 7680 7728
+sleep 30
+sbatch -J label161 submit_labeling.sh 7728 7776
+sleep 30
+sbatch -J label162 submit_labeling.sh 7776 7824
+sleep 30
+sbatch -J label163 submit_labeling.sh 7824 7872
+sleep 30
+sbatch -J label164 submit_labeling.sh 7872 7920
+sleep 30
+sbatch -J label165 submit_labeling.sh 7920 7968
+sleep 30
+sbatch -J label166 submit_labeling.sh 7968 8016
+sleep 30
+sbatch -J label167 submit_labeling.sh 8016 8064
+sleep 30
+sbatch -J label168 submit_labeling.sh 8064 8112
+sleep 30
+sbatch -J label169 submit_labeling.sh 8112 8160
+sleep 30
+sbatch -J label170 submit_labeling.sh 8160 8208
+sleep 30
+sbatch -J label171 submit_labeling.sh 8208 8256
+sleep 30
+sbatch -J label172 submit_labeling.sh 8256 8304
+sleep 30
+sbatch -J label173 submit_labeling.sh 8304 8352
+sleep 30
+sbatch -J label174 submit_labeling.sh 8352 8400
+sleep 30
+sbatch -J label175 submit_labeling.sh 8400 8448
+sleep 30
+sbatch -J label176 submit_labeling.sh 8448 8496
+sleep 30
+sbatch -J label177 submit_labeling.sh 8496 8544
+sleep 30
+sbatch -J label178 submit_labeling.sh 8544 8592
+sleep 30
+sbatch -J label179 submit_labeling.sh 8592 8640
+sleep 30
+sbatch -J label180 submit_labeling.sh 8640 8688
+sleep 30
+sbatch -J label181 submit_labeling.sh 8688 8736
+sleep 30
+sbatch -J label182 submit_labeling.sh 8736 8784
+sleep 30
+sbatch -J label183 submit_labeling.sh 8784 8832
+sleep 30
+sbatch -J label184 submit_labeling.sh 8832 8880
+sleep 30
+sbatch -J label185 submit_labeling.sh 8880 8928
+sleep 30
+sbatch -J label186 submit_labeling.sh 8928 8976
+sleep 30
+sbatch -J label187 submit_labeling.sh 8976 9024
+sleep 30
+sbatch -J label188 submit_labeling.sh 9024 9072
+sleep 30
+sbatch -J label189 submit_labeling.sh 9072 9120
+sleep 30
+sbatch -J label190 submit_labeling.sh 9120 9168
+sleep 30
+sbatch -J label191 submit_labeling.sh 9168 9216
+sleep 30
+sbatch -J label192 submit_labeling.sh 9216 9264
+sleep 30
+sbatch -J label193 submit_labeling.sh 9264 9312
+sleep 30
+sbatch -J label194 submit_labeling.sh 9312 9360
+sleep 30
+sbatch -J label195 submit_labeling.sh 9360 9408
+sleep 30
+sbatch -J label196 submit_labeling.sh 9408 9456
+sleep 30
+sbatch -J label197 submit_labeling.sh 9456 9504
+sleep 30
+sbatch -J label198 submit_labeling.sh 9504 9552
+sleep 30
+sbatch -J label199 submit_labeling.sh 9552 9600
+sleep 30
+sbatch -J label200 submit_labeling.sh 9600 9648
+sleep 30
+sbatch -J label201 submit_labeling.sh 9648 9696
+sleep 30
+sbatch -J label202 submit_labeling.sh 9696 9744
+sleep 30
+sbatch -J label203 submit_labeling.sh 9744 9792
+sleep 30
+sbatch -J label204 submit_labeling.sh 9792 9840
+sleep 30
+sbatch -J label205 submit_labeling.sh 9840 9888
+sleep 30
+sbatch -J label206 submit_labeling.sh 9888 9936
+sleep 30
+sbatch -J label207 submit_labeling.sh 9936 9984
+sleep 30
+sbatch -J label208 submit_labeling.sh 9984 10032
+sleep 30
+sbatch -J label209 submit_labeling.sh 10032 10080
+sleep 30
+sbatch -J label210 submit_labeling.sh 10080 10128
+sleep 30
+sbatch -J label211 submit_labeling.sh 10128 10176
+sleep 30
+sbatch -J label212 submit_labeling.sh 10176 10224
+sleep 30
+sbatch -J label213 submit_labeling.sh 10224 10272
+sleep 30
+sbatch -J label214 submit_labeling.sh 10272 10320
+sleep 30
+sbatch -J label215 submit_labeling.sh 10320 10368
+sleep 30
+sbatch -J label216 submit_labeling.sh 10368 10416
+sleep 30
+sbatch -J label217 submit_labeling.sh 10416 10464
+sleep 30
+sbatch -J label218 submit_labeling.sh 10464 10512
+sleep 30
+sbatch -J label219 submit_labeling.sh 10512 10560
+sleep 30
+sbatch -J label220 submit_labeling.sh 10560 10608
+sleep 30
+sbatch -J label221 submit_labeling.sh 10608 10656
+sleep 30
+sbatch -J label222 submit_labeling.sh 10656 10704
+sleep 30
+sbatch -J label223 submit_labeling.sh 10704 10752
+sleep 30
+sbatch -J label224 submit_labeling.sh 10752 10800
+sleep 30
+sbatch -J label225 submit_labeling.sh 10800 10848
+sleep 30
+sbatch -J label226 submit_labeling.sh 10848 10896
+sleep 30
+sbatch -J label227 submit_labeling.sh 10896 10944
+sleep 30
+sbatch -J label228 submit_labeling.sh 10944 10992
+sleep 30
+sbatch -J label229 submit_labeling.sh 10992 11040
+sleep 30
+sbatch -J label230 submit_labeling.sh 11040 11088
+sleep 30
+sbatch -J label231 submit_labeling.sh 11088 11136
+sleep 30
+sbatch -J label232 submit_labeling.sh 11136 11184
+sleep 30
+sbatch -J label233 submit_labeling.sh 11184 11232
+sleep 30
+sbatch -J label234 submit_labeling.sh 11232 11280
+sleep 30
+sbatch -J label235 submit_labeling.sh 11280 11328
+sleep 30
+sbatch -J label236 submit_labeling.sh 11328 11376
+sleep 30
+sbatch -J label237 submit_labeling.sh 11376 11424
+sleep 30
+sbatch -J label238 submit_labeling.sh 11424 11472
+sleep 30
+sbatch -J label239 submit_labeling.sh 11472 11520
+sleep 30
+sbatch -J label240 submit_labeling.sh 11520 11568
+sleep 30
+sbatch -J label241 submit_labeling.sh 11568 11616
+sleep 30
+sbatch -J label242 submit_labeling.sh 11616 11664
+sleep 30
+sbatch -J label243 submit_labeling.sh 11664 11699
+sleep 30
diff --git a/multimodal/bertscore_eval.py b/multimodal/bertscore_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c64a68866c05f32687a461eb4e2aa30d32cdabd5
--- /dev/null
+++ b/multimodal/bertscore_eval.py
@@ -0,0 +1,63 @@
+from open_flamingo.eval.vqa_metric import compute_vqa_accuracy
+import sys
+import json
+from bert_score import BERTScorer
+from tqdm.contrib.concurrent import process_map
+from tqdm import tqdm
+import random
+import time
+NUM_GPU = 128
+
+def single_job(args):
+    data, refs, idx = args
+    success = False
+    while not success:
+        try:
+            time.sleep(random.random()*10)
+            scorer = BERTScorer(
+                lang="en",
+                rescale_with_baseline=True,
+                # model_type="microsoft/deberta-xlarge-mnli",
+                model_type="bert-base-uncased",
+                batch_size=4096,
+                device=f"cuda:{idx % 6}"
+            )
+            success = True
+        except:
+            time.sleep(random.random()*5)
+    for i, d in enumerate(tqdm(data, disable=(idx != 0))):
+        if d["answer"] == "":
+            continue
+        cands = [d["answer"]] * len(refs)
+        P, R, F1 = scorer.score(cands, refs, verbose=False)
+        d["answer"] = refs[F1.argmax()]
+        data[i] = d
+    return data
+
+
+if __name__ == "__main__":
+    if sys.argv[1] == "vqav2":
+        question_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_OpenEnded_mscoco_val2014_questions.json"
+        annotation_json_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/task/open_flamingo/vqav2/v2_mscoco_val2014_annotations.json"
+    else:
+        raise NotImplementedError
+    answer_list = json.load(open("answer_list.json"))
+    data = json.load(open(sys.argv[2]))
+    cands = []
+    refs = []
+    data_parts = []
+    for i in range(NUM_GPU):
+        data_parts.append([[], answer_list, i])
+    for i, d in enumerate(data):
+        data_parts[i % NUM_GPU][0].append(d)
+    datas = process_map(single_job, data_parts, max_workers=NUM_GPU, disable=True)
+    all_data = []
+    for data in datas:
+        all_data.extend(data)
+    json.dump(all_data, open("temp_result", "w"))
+    acc = compute_vqa_accuracy(
+        result_json_path="temp_result",
+        question_json_path=question_json_path,
+        annotation_json_path=annotation_json_path,
+    )
+    print(acc)
diff --git a/multimodal/draw_all.py b/multimodal/draw_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..935e320d4e2345066970dd27db240286b797ce19
--- /dev/null
+++ b/multimodal/draw_all.py
@@ -0,0 +1,51 @@
+import glob
+import os
+import matplotlib.pyplot as plt
+COLORS = ["green", "blue", "orange", "black", "purple", "gray", "gold", "red", "gold", "yellow"]
+color_i = 0
+
+if __name__ == "__main__":
+    data = {}
+    color = {}
+    filenames = glob.glob("./eval_results/*")
+    for filename in filenames:
+        if "ig" in filename or "0713" in filename or "0819" in filename:
+            continue
+        items = filename.split("/")[-1].split("_")
+        if len(items) < 5:
+            continue
+        task = items[0]
+        if task == "ok":
+            task = "okvqa"
+            exp = "_".join(items[2:-3])
+        else:
+            exp = "_".join(items[1:-3])
+        if "fix" not in exp:
+            step = int(items[-3])
+            if "13" in exp:
+                step //= 2
+            score = float(items[-1])
+            if task not in data:
+                data[task] = {}
+            if exp not in data[task]:
+                data[task][exp] = []
+            data[task][exp].append([step, score])
+            if exp not in color:
+                color[exp] = COLORS[color_i]
+                color_i += 1
+    for task in data:
+        for exp in data[task]:
+            data[task][exp] = sorted(data[task][exp], key=lambda x: x[0])
+    for task in data:
+        plt.figure()
+        plt.title(f"{task} evaluation")
+        for exp in data[task]:
+            steps = [x[0] for x in data[task][exp]]
+            scores = [x[1] for x in data[task][exp]]
+            plt.plot(steps, scores, "-o", color=color[exp], label=exp)
+        plt.grid()
+        plt.legend()
+        plt.xlabel("step")
+        plt.xlim(0, 15000)
+        plt.savefig(f"eval_results/{task}.jpg")
+
diff --git a/multimodal/environment.yml b/multimodal/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..2c010aa1a627eedc92badd5d54f55956bf617751
--- /dev/null
+++ b/multimodal/environment.yml
@@ -0,0 +1,10 @@
+name: mm
+channels:
+  - defaults
+dependencies:
+  - python=3.9
+  - conda-forge::openjdk
+  - pip
+  - pip:
+    - -r requirements.txt
+    - -e .
diff --git a/multimodal/example.py b/multimodal/example.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ad929f1c5e05be634e6068afce3f5bbad1f2940
--- /dev/null
+++ b/multimodal/example.py
@@ -0,0 +1,33 @@
+build_model = None
+ZeroRedundancyOptimizer = None
+GradScaler = None
+laion_loader = None
+pile_loader = None
+autocast = None
+zero_embedding_gradient = None
+torch = None
+lr_scheduler = None
+get_cosine_schedule_with_warmup = None
+
+
+ddp_model = build_model(...)
+optimizer = ZeroRedundancyOptimizer(...)
+lr_scheduler = get_cosine_schedule_with_warmup(...)
+scaler = GradScaler()
+
+for batch_laion, batch_pile in zip(laion_loader, pile_loader):
+    with autocast():
+        loss_laion = ddp_model(batch_laion)
+    scaler.scale(loss_laion).backward()
+    with autocast():
+        loss_pile = ddp_model(batch_pile)
+    scaler.scale(loss_pile).backward()
+
+    zero_embedding_gradient()
+    scaler.unscale_(optimizer)
+    torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 1.0)
+
+    scaler.step(optimizer)
+    scaler.update()
+    lr_scheduler.step()
+    optimizer.zero_grad()
diff --git a/multimodal/four_cats.png b/multimodal/four_cats.png
new file mode 100644
index 0000000000000000000000000000000000000000..56d981d712d7b96fe388ef6fb9cb8302eb23de81
Binary files /dev/null and b/multimodal/four_cats.png differ
diff --git a/multimodal/generate_batch_submit.py b/multimodal/generate_batch_submit.py
new file mode 100644
index 0000000000000000000000000000000000000000..067cdf0d2c30fc6cef0ab3e27dc1dc0ae3af9d0d
--- /dev/null
+++ b/multimodal/generate_batch_submit.py
@@ -0,0 +1,9 @@
+import os
+import sys
+start_idx = sys.argv[1]
+end_idx = sys.argv[2]
+
+with open("batch_submit.sh", "w") as f:
+    for i, idx in enumerate(range(int(start_idx), int(end_idx), 48)):
+        f.write(f"sbatch -J label{i} submit_labeling.sh {idx} {idx+48}\n")
+        f.write("sleep 30\n")
diff --git a/multimodal/interp_sam.py b/multimodal/interp_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b1adb52dd69756a2cbe9c739569091dfec0176a
--- /dev/null
+++ b/multimodal/interp_sam.py
@@ -0,0 +1,30 @@
+import os
+import torch
+ORI_IMAGE_SIZE = 1024
+IMAGE_SIZE = 256
+REL_POS = 31
+
+checkpoint = torch.load("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195.pth")
+image_encoder_pos_embed = checkpoint["image_encoder.pos_embed"]
+image_encoder_pos_embed = torch.nn.functional.interpolate(image_encoder_pos_embed.permute(0, 3, 1, 2), scale_factor=IMAGE_SIZE / ORI_IMAGE_SIZE, mode="bilinear", align_corners=True).permute(0, 2, 3, 1)
+checkpoint["image_encoder.pos_embed"] = image_encoder_pos_embed
+print(image_encoder_pos_embed.shape)
+
+for idx in [5, 11, 17, 23]:
+    rel_pos_h = checkpoint[f"image_encoder.blocks.{idx}.attn.rel_pos_h"]
+    rel_pos_w = checkpoint[f"image_encoder.blocks.{idx}.attn.rel_pos_w"]
+    rel_pos_h = torch.nn.functional.interpolate(
+        rel_pos_h.permute(1, 0).unsqueeze(0),
+        size=REL_POS, mode="linear",
+        align_corners=True,
+    ).squeeze(0).permute(1, 0)
+    rel_pos_w = torch.nn.functional.interpolate(
+        rel_pos_w.permute(1, 0).unsqueeze(0),
+        size=REL_POS, mode="linear",
+        align_corners=True,
+    ).squeeze(0).permute(1, 0)
+    checkpoint[f"image_encoder.blocks.{idx}.attn.rel_pos_h"] = rel_pos_h
+    checkpoint[f"image_encoder.blocks.{idx}.attn.rel_pos_w"] = rel_pos_w
+    print(rel_pos_h.shape, rel_pos_w.shape)
+
+torch.save(checkpoint, f"/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_{IMAGE_SIZE}x{IMAGE_SIZE}.pth")
diff --git a/multimodal/local_train.sh b/multimodal/local_train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7555e1affba77826372d076a6eb70b4686a46c47
--- /dev/null
+++ b/multimodal/local_train.sh
@@ -0,0 +1,38 @@
+LAION_DATA=/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/karpathy_coco_wds_full_ground/{00000..00066}.tar
+PILE_DATA=/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/the_pile/{00000..01925}.tar
+SAVE_DIR=checkpoints_local/debug0922
+mkdir -p ${SAVE_DIR}
+cp $0 ${SAVE_DIR}/
+
+export TRANSFORMERS_OFFLINE=1
+torchrun --nnodes=1 --nproc_per_node=6 --master_port=14288 open_flamingo/train/train.py \
+--run_name ${SAVE_DIR} \
+--vision_encoder_path ViT-L-14 \
+--vision_encoder_pretrained datacomp_xl_s13b_b90k \
+--lm_path EleutherAI/pythia-1.4b \
+--tokenizer_path EleutherAI/pythia-1.4b \
+--dataset_resampled \
+--laion_shards ${LAION_DATA} \
+--pile_shards ${PILE_DATA} \
+--batch_size_laion 14 \
+--batch_size_pile 2 \
+--workers=4 \
+--lr_scheduler cosine \
+--warmup_steps 200 \
+--num_steps 4000 \
+--checkpoint_activations \
+--delete_previous_checkpoint \
+--gradient_accumulation_steps 1 \
+--save_interval 100 \
+--logging_steps 2 \
+--skip_delete_pattern 500 \
+--precision amp_fp16 \
+--learning_rate 1.0e-5 \
+--add_visual_token \
+--max-length 960 \
+--loss_multiplier_det 0.025 \
+--add_box \
+--expand \
+--use_format_v2 \
+--resume_from_checkpoint checkpoints/091701_pythiaS_previsual_fix/checkpoint_20000.pt \
+--restart
diff --git a/multimodal/offline_grounding_dino.py b/multimodal/offline_grounding_dino.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e55b5b145cdb0e83be91a418b1a800d813f7034
--- /dev/null
+++ b/multimodal/offline_grounding_dino.py
@@ -0,0 +1,68 @@
+import webdataset as wds
+from groundingdino.demo.caption_grounder import caption_grounder
+from tqdm import tqdm
+import sys
+import os
+
+# SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/laion_synthetic_filtered_large/all"
+# DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/laion_synthetic_filtered_large/all_ground"
+
+# SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/ccs_synthetic_filtered_large"
+# DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/ccs_synthetic_filtered_large_ground"
+
+# SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/karpathy_coco_wds_full"
+# DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/karpathy_coco_wds_full_ground"
+
+# SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_wds_full"
+# DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_wds_full_ground"
+SOURCE_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/all_data_0620"
+DEST_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/all_data_ground_0701"
+
+def augment_wds(url, output, generator):
+    src = (
+        wds.WebDataset(url)
+        .decode("pilrgb")
+        .to_tuple("__key__", "jpg;png;jpeg", "txt")
+    )
+    
+    with wds.TarWriter(output) as dst:
+        for key, image, caption in tqdm(src, total=10000):
+            # jpg txt json
+            # image = image.resize((224, 224))
+            logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption)
+            sample = {
+                "__key__": key,
+                "jpg": image,
+                "txt": caption,
+                "logits.pyd": logits,
+                "boxes.pyd": boxes,
+            }
+            dst.write(sample)
+
+
+if __name__ == "__main__":
+    print("FROM", os.path.join(SOURCE_DIR, sys.argv[2]+".tar"))
+    print("TO", os.path.join(DEST_DIR, sys.argv[2]+".tar"))
+    # if os.path.exists(os.path.join(DEST_DIR, sys.argv[2]+".tar")):
+    #     print("already done. exiting...")
+    #     exit()
+    success = False
+    while not success:
+        try:
+            generator = caption_grounder(
+                config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinB.cfg.py",
+                checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swinb_cogcoor.pth",
+                cpu_only=False,
+                box_threshold=0.05,
+            )
+            success = True
+        except:
+            import random
+            import time
+            time.sleep(random.random() * 5)
+    augment_wds(
+        os.path.join(SOURCE_DIR, sys.argv[2]+".tar"),
+        os.path.join(DEST_DIR, sys.argv[2]+".tar"),
+        generator=generator,
+    )
+    print("DONE")
diff --git a/multimodal/offline_labeling.py b/multimodal/offline_labeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9273a1c1a44f21f3e7197d7b60251cc7e71a81d
--- /dev/null
+++ b/multimodal/offline_labeling.py
@@ -0,0 +1,22 @@
+import os
+import sys
+GPU_PER_NODE = 6
+TASK_PER_GPU = 8
+
+if __name__ == "__main__":
+    split = sys.argv[1]
+    start_idx = sys.argv[2]
+    end_idx = sys.argv[3]
+    job_id = os.environ["SLURM_JOBID"]
+    gpu_id = 0
+    job_bash = f"temp/job/{job_id}.sh"
+    with open(job_bash, "w") as f:
+        f.write("export TRANSFORMERS_OFFLINE=1\n")
+        for i, idx in enumerate(range(int(start_idx), int(end_idx))):
+            zfill_idx = str(idx).zfill(6)
+            f.write(f"CUDA_VISIBLE_DEVICES={gpu_id} python3 offline_grounding_dino.py {split} {zfill_idx} &> temp/log/{split}_{zfill_idx}_{job_id}_{gpu_id}.txt &\n")
+            gpu_id = (gpu_id + 1) % GPU_PER_NODE
+        f.write("sleep 7200\n")
+    print("run!")
+    os.system(f"bash {job_bash}")
+    print("end!")
diff --git a/multimodal/open_flamingo/__init__.py b/multimodal/open_flamingo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab67750bb75534afaeb8876c065e32d4861f3052
--- /dev/null
+++ b/multimodal/open_flamingo/__init__.py
@@ -0,0 +1,2 @@
+from .src.flamingo import Flamingo
+from .src.factory import create_model_and_transforms
diff --git a/multimodal/open_flamingo/eval/__init__.py b/multimodal/open_flamingo/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/multimodal/open_flamingo/eval/__init__.py
@@ -0,0 +1 @@
+
diff --git a/multimodal/open_flamingo/eval/classification.py b/multimodal/open_flamingo/eval/classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6503b30a370faeb98410f158e50a8bc0e4484b2
--- /dev/null
+++ b/multimodal/open_flamingo/eval/classification.py
@@ -0,0 +1,147 @@
+from typing import Dict, Sequence, Tuple
+import re
+import numpy as np
+import torch
+
+
+def postprocess_classification_generation(predictions) -> str:
+    return re.split("Prompt|Completion", predictions, 1)[0]
+
+
+def compute_classification_accuracy(predictions: Sequence[Dict[str, str]]) -> float:
+    """Compute the accuracy of a sequence of predictions."""
+
+    def _preprocess_fn(s):
+        """Function to preprocess both targets and predictions."""
+        return s.lower()
+
+    is_correct = [
+        _preprocess_fn(x["prediction"]) == _preprocess_fn(x["class_label"])
+        for x in predictions
+    ]
+
+    return np.mean(is_correct).item()
+
+
+def compute_shifted_logits_and_labels(
+    logits: torch.Tensor, encodings, tokenizer, eoc_token_id
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Helper function to compute shifted logits and labels.
+
+    This allows for straightforward computation of the loss on shift_logits
+    and shift_labels such that the nth element of logits computes the n-1th
+    element of the original labels (in the outputs, the nth element of logits
+    corresponds to the nth element of the labels).
+
+    Elements in shift_labels that correspond to inputs are masked with values
+    of -100 (by default in hf, loss is only computed on token IDs >= 0).
+
+    Returns: tuple containing two elements:
+        shift_logits: a float Tensor of shape [batch_size, seq_len - 1].
+        shift_labels: an integer Tensor of shape [batch_size, seq_len - 1]
+    """
+
+    labels = encodings["input_ids"].clone()
+
+    # convert padding and EOC tokens to -100 so they are ignored in loss
+    labels[labels == tokenizer.pad_token_id] = -100
+    labels[labels == eoc_token_id] = -100
+
+    # Convert all tokens in prefix until separator to -100 so they are
+    # ignored in loss
+    for idx in range(len(labels)):
+        # Find the location of the last token of prefix *from right*,
+        # since the first non-padding token of the sequence will also be
+        # eos_token (because bos_token and eos_token are the same for
+        # the tokenizer).
+        end_of_prefix = -labels[idx].tolist()[::-1].index(tokenizer.eos_token_id) - 1
+        labels[idx, : end_of_prefix + 1] = -100
+
+    # Shift so that tokens < n predict n. The shifted tensors both have
+    # shape [batch_size, seq_len - 1].
+    shift_logits = logits[..., :-1, :].contiguous()
+    shift_labels = labels[..., 1:].contiguous()
+
+    return shift_logits, shift_labels
+
+
+def compute_per_sample_probs(
+    encodings, tokenizer, logits: torch.Tensor, eoc_token_id
+) -> torch.Tensor:
+    """Helper function to compute per-sample probability of the input sequence.
+
+    Assumes <eos token> is used to separate inputs from targets in the
+    prompt text
+    """
+    shift_logits, shift_labels = compute_shifted_logits_and_labels(
+        logits, encodings, tokenizer, eoc_token_id
+    )
+
+    # Tuple of tensors for unmasked label tokens. The first element of the
+    # tuple contains the batch indices; the second element contains the
+    # sequence indices.
+    unmasked_indices = torch.nonzero(shift_labels != -100, as_tuple=True)
+    # Tensor where the i^th element is the token_id corresponding to the i^th
+    # element of unmasked_indices
+    unmasked_token_ids = shift_labels[unmasked_indices]
+
+    # 3d tensor of [batch_idx, sequence_position, token_id] for unmasked tokens.
+    target_idxs = torch.column_stack([*unmasked_indices, unmasked_token_ids])
+    target_idxs = target_idxs.to(shift_logits.device)
+
+    # Sanity check that every element in batch has at least one unmasked
+    # target token
+    assert torch.all(
+        torch.bincount(target_idxs[:, 0]) != 0
+    ), "At least one element in batch has no unmasked target tokens."
+
+    # Renormalize over tokens to make sure they are proper probabilities via
+    # softmax over the token dimension.
+    shift_probs = torch.nn.functional.softmax(shift_logits, 2)
+
+    # Compute the probability of the target sequence (as the product of the
+    # probability of the individual tokens in the sequence).
+    target_probs = torch.ones(len(shift_labels), device=shift_logits.device)
+    for i, j, k in target_idxs:
+        target_probs[i] *= shift_probs[i, j, k]
+
+    return target_probs
+
+
+def compute_per_sample_loss(encodings, tokenizer, logits, eoc_token_id) -> torch.Tensor:
+    """Helper function to compute per-sample classification loss.
+
+    Assumes <eos token> is used to separate inputs from targets in the
+    prompt text
+    """
+    shift_logits, shift_labels = compute_shifted_logits_and_labels(
+        logits, encodings, tokenizer, eoc_token_id
+    )
+
+    device = shift_logits.device
+
+    # Loss is computed token-wise, on Tensors of shape
+    # [batch_size * (seq_len - 1), vocab_size]
+    # and returns a loss tensor of shape
+    # [batch_size * (seq_len - 1)]. Most of the tokens will be masked
+    # in this computation.
+    loss = torch.nn.functional.cross_entropy(
+        shift_logits.view(-1, shift_logits.size(-1)),
+        shift_labels.view(-1).to(device),
+        reduction="none",
+    )
+
+    # Reshape to [batch_size, seq_len - 1]
+    loss = loss.view(shift_logits.size(0), shift_logits.size(1)).cpu()
+
+    # loss_mask is 1 for tokens we want included in the loss, and 0 for tokens
+    # that should be ignored in the loss.
+    loss_mask = (shift_labels != -100).int().cpu()
+
+    loss *= loss_mask
+
+    # Compute per-element loss : sum loss over all (unmasked) tokens and
+    # divide by number of variable tokens to obtain tensor of
+    # shape [batch_size,]
+    loss = loss.sum(dim=1) / (shift_labels != -100).sum(dim=1).float()
+    return loss
diff --git a/multimodal/open_flamingo/eval/coco_metric.py b/multimodal/open_flamingo/eval/coco_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..db0e9a4154e9e72788f2353dec59ea0ee32187b8
--- /dev/null
+++ b/multimodal/open_flamingo/eval/coco_metric.py
@@ -0,0 +1,23 @@
+from pycocoevalcap.eval import COCOEvalCap
+from pycocotools.coco import COCO
+import json
+
+
+def compute_cider(
+    result_path,
+    annotations_path,
+):
+    # create coco object and coco_result object
+    coco = COCO(annotations_path)
+    coco_result = coco.loadRes(result_path)
+
+    # create coco_eval object by taking coco and coco_result
+    coco_eval = COCOEvalCap(coco, coco_result)
+    coco_eval.params["image_id"] = coco_result.getImgIds()
+    coco_eval.evaluate()
+
+    return coco_eval.eval
+
+
+def postprocess_captioning_generation(predictions):
+    return predictions
diff --git a/multimodal/open_flamingo/eval/dataset_zoo/__init__.py b/multimodal/open_flamingo/eval/dataset_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a3a44ad27ac4a4e9b1cc84bce075329b47141cc
--- /dev/null
+++ b/multimodal/open_flamingo/eval/dataset_zoo/__init__.py
@@ -0,0 +1,33 @@
+from .aro_datasets import VG_Relation, VG_Attribution, COCO_Order, Flickr30k_Order
+from .retrieval import COCO_Retrieval, Flickr30k_Retrieval
+
+
+def get_dataset(dataset_name, image_preprocess=None, text_perturb_fn=None, image_perturb_fn=None, download=False, *args, **kwargs):
+    """
+    Helper function that returns a dataset object with an evaluation function. 
+    dataset_name: Name of the dataset.
+    image_preprocess: Preprocessing function for images.
+    text_perturb_fn: A function that takes in a string and returns a string. This is for perturbation experiments.
+    image_perturb_fn: A function that takes in a PIL image and returns a PIL image. This is for perturbation experiments.
+    download: Whether to allow downloading images if they are not found.
+    """
+    if dataset_name == "VG_Relation": 
+        from .aro_datasets import get_visual_genome_relation
+        return get_visual_genome_relation(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
+    elif dataset_name == "VG_Attribution":
+        from .aro_datasets import get_visual_genome_attribution
+        return get_visual_genome_attribution(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
+    elif dataset_name == "COCO_Order":
+        from .aro_datasets import get_coco_order
+        return get_coco_order(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
+    elif dataset_name == "Flickr30k_Order":
+        from .aro_datasets import get_flickr30k_order
+        return get_flickr30k_order(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
+    elif dataset_name == "COCO_Retrieval":
+        from .retrieval import get_coco_retrieval
+        return get_coco_retrieval(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
+    elif dataset_name == "Flickr30k_Retrieval":
+        from .retrieval import get_flickr30k_retrieval
+        return get_flickr30k_retrieval(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download, *args, **kwargs)
+    else:
+        raise ValueError(f"Unknown dataset {dataset_name}")
diff --git a/multimodal/open_flamingo/eval/dataset_zoo/aro_datasets.py b/multimodal/open_flamingo/eval/dataset_zoo/aro_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..14e91d500c01dd232c8a7b23fb9af266ecd1d513
--- /dev/null
+++ b/multimodal/open_flamingo/eval/dataset_zoo/aro_datasets.py
@@ -0,0 +1,365 @@
+import os
+import json
+import subprocess
+
+import numpy as np
+
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset
+from easydict import EasyDict as edict
+from torchvision.datasets.utils import download_url
+
+from .perturbations import TextShuffler
+from .constants import ARO_ROOT, COCO_ROOT, FLICKR_ROOT
+from .retrieval import pre_caption
+
+
+class VG_Relation(Dataset):
+    def __init__(self, image_preprocess, text_perturb_fn=None, image_perturb_fn=None, root_dir=ARO_ROOT, download=False):
+        '''
+        image_preprocess: a function that takes in a PIL image and returns a tensor.
+        text_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
+        image_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
+        root_dir: Directory for the VG-R dataset.
+        download: Whether to download the dataset if it does not exist.
+        '''
+        self.root_dir = root_dir
+        annotation_file = os.path.join(root_dir, "visual_genome_relation.json")
+        image_dir = os.path.join(root_dir, "images")
+        if not os.path.exists(image_dir):
+            print("Image Directory for VG_Relation could not be found!")
+            if download:
+                self.download()
+            else:
+                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
+        
+        if not os.path.exists(annotation_file):
+            subprocess.call(["gdown", "--id", "1kX2iCHEv0CADL8dSO1nMdW-V0NqIAiP3", "--output", annotation_file])
+        
+        with open(annotation_file, "r") as f:
+            self.dataset = json.load(f)
+        
+        self.all_relations = list()
+        for item in self.dataset:
+            item["image_path"] = os.path.join(image_dir, item["image_path"])
+            self.all_relations.append(item["relation_name"])
+
+        self.image_preprocess = image_preprocess
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        test_case = self.dataset[index]
+        image = Image.open(test_case["image_path"]).convert('RGB')
+        # Get the bounding box that contains the relation. This is to remove the irrelevant details in the scene.
+        image = image.crop((test_case["bbox_x"], test_case["bbox_y"], test_case["bbox_x"] + test_case["bbox_w"], test_case["bbox_y"] + test_case["bbox_h"]))
+
+        if self.image_preprocess is not None:
+            image = self.image_preprocess(image)
+
+        # Each test case has a correct and incorrect caption.
+        true_caption = test_case["true_caption"]
+        false_caption = test_case["false_caption"]
+        item = edict({"image_options": [image], "caption_options": [false_caption, true_caption]})
+        return item
+    
+    def download(self):
+        os.makedirs(self.root_dir, exist_ok=True)
+        image_zip_file = os.path.join(self.root_dir, "vgr_vga_images.zip")
+        subprocess.call(["gdown", "--no-cookies", "1qaPlrwhGNMrR3a11iopZUT_GPP_LrgP9", "--output", image_zip_file])
+        subprocess.call(["unzip", "vgr_vga_images.zip"], cwd=self.root_dir)
+
+        
+    def evaluate_scores(self, scores):
+        """
+        Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one
+        """
+        if isinstance(scores, tuple):
+            scores_i2t = scores[1]
+            scores_t2i = scores[0] 
+        else:
+            scores_t2i = scores
+            scores_i2t = scores
+
+        metrics = {"Accuracy": None}
+        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
+        correct_mask = (preds == 1)
+        metrics["Accuracy"] = np.mean(correct_mask)
+
+        all_relations = np.array(self.all_relations)
+
+        result_records = []
+        # Log the accuracy of all relations
+        for relation in np.unique(all_relations):
+            relation_mask = (all_relations == relation)
+            if relation_mask.sum() == 0:
+                continue
+            result_records.append({
+                "Relation": relation,
+                "Accuracy": correct_mask[relation_mask].mean(),
+                "Count": relation_mask.sum(),
+                "Dataset": "Visual Genome Relation"
+            })
+        return result_records
+
+
+
+class VG_Attribution(Dataset):
+    def __init__(self, image_preprocess, text_perturb_fn=None, image_perturb_fn=None, root_dir=ARO_ROOT, download=False):
+        '''
+        image_preprocess: a function that takes in a PIL image and returns a tensor.
+        text_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
+        image_perturb_fn: Not used for this dataset. Just for compatibility with other datasets.
+        root_dir: Directory for the VG-A dataset.
+        '''
+        self.root_dir = root_dir
+        annotation_file = os.path.join(root_dir, "visual_genome_attribution.json")
+        image_dir = os.path.join(root_dir, "images")
+        if not os.path.exists(image_dir):
+            print("Image Directory for VG_Attribution could not be found!")
+            if download:
+                self.download()
+            else:
+                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
+        
+        
+        if not os.path.exists(annotation_file):
+            subprocess.call(["gdown", "--id", "13tWvOrNOLHxl3Rm9cR3geAdHx2qR3-Tw", "--output", annotation_file])
+
+        with open(annotation_file, "r") as f:
+            self.dataset = json.load(f)
+        
+        for item in self.dataset:
+            item["image_path"] = os.path.join(image_dir, item["image_path"])
+        
+        # Set of attributes in each test case
+        self.all_attributes = [f"{item['attributes'][0]}_{item['attributes'][1]}" for item in self.dataset]
+        self.image_preprocess = image_preprocess
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        test_case = self.dataset[index]
+        image = Image.open(test_case["image_path"]).convert('RGB')
+        # Get the bounding box that contains the relation. This is to remove the irrelevant details in the scene.
+        image = image.crop((test_case["bbox_x"], test_case["bbox_y"], test_case["bbox_x"] + test_case["bbox_w"], test_case["bbox_y"] + test_case["bbox_h"]))
+
+        if self.image_preprocess is not None:
+            image = self.image_preprocess(image)
+
+        # Each test case has a correct and incorrect caption.
+        true_caption = test_case["true_caption"]
+        false_caption = test_case["false_caption"]
+        item = edict({"image_options": [image], "caption_options": [false_caption, true_caption]})
+        return item
+    
+    def download(self):
+        os.makedirs(self.root_dir, exist_ok=True)
+        image_zip_file = os.path.join(self.root_dir, "vgr_vga_images.zip")
+        subprocess.call(["gdown", "--no-cookies",  "1qaPlrwhGNMrR3a11iopZUT_GPP_LrgP9", "--output", image_zip_file])
+        subprocess.call(["unzip", "vgr_vga_images.zip"], cwd=self.root_dir)
+
+    
+    def evaluate_scores(self, scores):
+        """
+        Scores: N x 1 x 2, i.e. first caption is the perturbed one, second is the positive one
+        """
+        if isinstance(scores, tuple):
+            scores_i2t = scores[1]
+            scores_t2i = scores[0] 
+        else:
+            scores_t2i = scores
+            scores_i2t = scores
+
+        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
+        correct_mask = (preds == 1)
+        result_records = []
+        all_attributes = np.array(self.all_attributes)
+        for attr in np.unique(all_attributes):
+            attr_mask = (all_attributes == attr)
+            if attr_mask.sum() < 25:
+                continue
+            result_records.append({
+                "Attributes": attr,
+                "Accuracy": correct_mask[attr_mask].mean(),
+                "Count": attr_mask.sum(),
+                "Dataset": "Visual Genome Attribution"
+            })
+        return result_records
+
+
+
+
+class COCO_Order(Dataset):
+    def __init__(self, image_preprocess=None, root_dir=COCO_ROOT, max_words=30, split="test",
+                 image_perturb_fn=None, download=False):  
+        """
+        COCO Order Dataset.
+        image_preprocess: image preprocessing function
+        root_dir: The directory of the coco dataset. This directory should contain test2014 files.
+        max_words: Cropping the caption to max_words.
+        split: 'val' or 'test'
+        image_perturb_fn: not used; for compatibility.
+        download: Whether to download the dataset if it does not exist.
+        """
+        shuffler = TextShuffler()
+        perturb_functions = [shuffler.shuffle_nouns_and_adj, shuffler.shuffle_allbut_nouns_and_adj,
+                             shuffler.shuffle_within_trigrams, shuffler.shuffle_trigrams]
+
+        self.root_dir = root_dir
+        if not os.path.exists(root_dir):
+            print("Directory for COCO could not be found!")
+            if download:
+                print("Downloading COCO now.")
+                self.download()
+            else:
+                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
+        
+        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+        download_url(urls[split],root_dir)
+        
+        self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
+        self.image_preprocess = image_preprocess
+        self.image_root = root_dir
+        
+        self.test_cases = []
+        
+        for img_id, ann in tqdm(enumerate(self.annotation)):
+            for i, caption in enumerate(ann['caption']):
+                test_case = {}
+                test_case["image"] = ann["image"]
+                test_case["caption_options"] = [pre_caption(caption,max_words)]
+
+                for perturb_fn in perturb_functions:
+                    test_case["caption_options"].append(pre_caption(perturb_fn(caption), max_words))
+                self.test_cases.append(test_case)
+                                    
+    def __len__(self):
+        return len(self.test_cases)
+    
+    def __getitem__(self, index):  
+        test_case = self.test_cases[index]  
+        image_path = os.path.join(self.image_root, test_case["image"])       
+         
+        image = Image.open(image_path).convert('RGB')    
+        if self.image_preprocess is not None: 
+            image = self.image_preprocess(image)  
+        
+        item = edict({"image_options": [image], "caption_options": test_case["caption_options"]})
+        return item
+    
+    def download(self):
+        import subprocess
+        os.makedirs(self.root_dir, exist_ok=True)
+        #subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=self.root_dir)
+        #subprocess.call(["unzip", "train2014.zip"], cwd=self.root_dir)
+        
+        subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
+        subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)
+        
+        subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
+        subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)
+        
+    
+    def evaluate_scores(self, scores):
+        if isinstance(scores, tuple):
+            scores_i2t = scores[0]
+            scores_t2i = scores[1].T # Make it N_ims x N_text
+        
+        else:
+            scores_t2i = scores
+            scores_i2t = scores
+        
+        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
+        correct_mask = (preds == 0)
+        records = [{"Precision@1": np.mean(correct_mask)}]
+        return records
+
+
+class Flickr30k_Order(Dataset):
+    def __init__(self, image_preprocess, split, root_dir=FLICKR_ROOT, max_words=30,
+                 *args, **kwargs):  
+        """
+        image_preprocess: image preprocessing function
+        split: 'val' or 'test'
+        root_dir: The directory of the flickr30k images. This should contain the `flickr30k-images` directory that \
+            contains all the images. 
+        """
+        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
+                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
+        filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
+        if not os.path.exists(root_dir):
+            print("Directory for Flickr30k could not be found!")
+            flickr_url = "https://forms.illinois.edu/sec/229675"
+            raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")
+        
+        download_url(urls[split],root_dir)
+        
+        self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
+        self.image_preprocess = image_preprocess
+        self.root_dir = root_dir
+        
+        self.test_cases = []
+        
+        shuffler = TextShuffler()
+        perturb_functions = [shuffler.shuffle_nouns_and_adj, shuffler.shuffle_allbut_nouns_and_adj,
+                             shuffler.shuffle_within_trigrams, shuffler.shuffle_trigrams]
+        for img_id, ann in tqdm(enumerate(self.annotation)):
+            for i, caption in enumerate(ann['caption']):
+                test_case = {}
+                test_case["image"] = ann["image"]
+                test_case["caption_options"] = [pre_caption(caption,max_words)]
+
+                for perturb_fn in perturb_functions:
+                    test_case["caption_options"].append(pre_caption(perturb_fn(caption), max_words))
+                self.test_cases.append(test_case)
+                                
+    def __len__(self):
+        return len(self.test_cases)
+    
+    def __getitem__(self, index):  
+        test_case = self.test_cases[index]  
+        image_path = os.path.join(self.root_dir, test_case["image"])        
+        image = Image.open(image_path).convert('RGB')    
+        
+        if self.image_preprocess is not None: 
+            image = self.image_preprocess(image)  
+            
+        item = edict({"image_options": [image], "caption_options": test_case["caption_options"]})
+        return item
+    
+    def evaluate_scores(self, scores):
+        if isinstance(scores, tuple):
+            scores_i2t = scores[0]
+            scores_t2i = scores[1].T # Make it N_ims x N_text
+        else:
+            scores_t2i = scores
+            scores_i2t = scores
+        
+        preds = np.argmax(np.squeeze(scores_i2t, axis=1), axis=-1)
+        correct_mask = (preds == 0)
+        result_records = [{"Precision@1": np.mean(correct_mask)}]
+        return result_records
+
+
+def get_visual_genome_relation(image_preprocess, text_perturb_fn=None, image_perturb_fn=None, download=False):
+    return VG_Relation(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn, image_perturb_fn=image_perturb_fn, download=download)
+
+
+def get_visual_genome_attribution(image_preprocess, text_perturb_fn=None, image_perturb_fn=None, download=False):
+    return VG_Attribution(image_preprocess=image_preprocess, text_perturb_fn=text_perturb_fn,
+                   image_perturb_fn=image_perturb_fn, download=download)
+
+def get_coco_order(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=COCO_ROOT, split="test"):
+    return COCO_Order(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words, 
+                            download=download)
+
+def get_flickr30k_order(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=FLICKR_ROOT, split="test"):
+    return Flickr30k_Order(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words, 
+                            download=download)
+
diff --git a/multimodal/open_flamingo/eval/dataset_zoo/constants.py b/multimodal/open_flamingo/eval/dataset_zoo/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..0141315bf2af13e6a47c081d65facd00717cc838
--- /dev/null
+++ b/multimodal/open_flamingo/eval/dataset_zoo/constants.py
@@ -0,0 +1,3 @@
+ARO_ROOT = "~/.cache/prerelease_bow"
+COCO_ROOT = "~/.cache/coco/2014"
+FLICKR_ROOT = "~/.cache/flickr30k/images"
diff --git a/multimodal/open_flamingo/eval/dataset_zoo/perturbations.py b/multimodal/open_flamingo/eval/dataset_zoo/perturbations.py
new file mode 100644
index 0000000000000000000000000000000000000000..159743b71f34f04ac49a371b22df9b3d76f4b5c8
--- /dev/null
+++ b/multimodal/open_flamingo/eval/dataset_zoo/perturbations.py
@@ -0,0 +1,194 @@
+import torch
+import random
+import numpy as np
+from functools import partial
+import torch.nn.functional as nnf
+from torchvision import transforms as T
+
+# A lot of the approaches here are inspired from the wonderful paper from O'Connor and Andreas 2021.
+# https://github.com/lingo-mit/context-ablations
+
+def get_text_perturb_fn(text_perturb_fn):
+    if text_perturb_fn == "shuffle_nouns_and_adj":
+        return shuffle_nouns_and_adj
+    elif text_perturb_fn == "shuffle_allbut_nouns_and_adj":
+        return shuffle_allbut_nouns_and_adj
+    elif text_perturb_fn == "shuffle_within_trigrams":
+        return shuffle_within_trigrams
+    elif text_perturb_fn == "shuffle_all_words":
+        return shuffle_all_words
+    elif text_perturb_fn == "shuffle_trigrams":
+        return shuffle_trigrams
+    elif text_perturb_fn is None:
+        return None
+    else:
+        print("Unknown text perturbation function: {}, returning None".format(text_perturb_fn))
+        return None
+    
+    
+def get_image_perturb_fn(image_perturb_fn):
+    if image_perturb_fn == "shuffle_rows_4":
+        return partial(shuffle_rows, n_rows=4)
+    elif image_perturb_fn == "shuffle_patches_9":
+        return partial(shuffle_patches, n_ratio=3)
+    elif image_perturb_fn == "shuffle_cols_4":
+        return partial(shuffle_columns, n_cols=4)
+    elif image_perturb_fn is None:
+        return None
+    else:
+        print("Unknown image perturbation function: {}, returning None".format(image_perturb_fn))
+        return None
+    
+
+
+class TextShuffler:
+
+    def __init__(self):
+        import spacy
+        self.nlp = spacy.load("en_core_web_sm")
+
+    def shuffle_nouns_and_adj(self, ex):
+
+        doc = self.nlp(ex)
+        tokens = [token.text for token in doc]
+        text = np.array(tokens)
+        noun_idx = [i for i, token in enumerate(doc) if token.tag_ in ['NN', 'NNS', 'NNP', 'NNPS']]
+        ## Finding adjectives
+        adjective_idx = [i for i, token in enumerate(doc) if token.tag_ in ['JJ', 'JJR', 'JJS']]
+        ## Shuffle the nouns of the text
+        text[noun_idx] = np.random.permutation(text[noun_idx])
+        ## Shuffle the adjectives of the text
+        text[adjective_idx] = np.random.permutation(text[adjective_idx])
+
+        return " ".join(text)
+
+    def shuffle_all_words(self, ex):
+        return " ".join(np.random.permutation(ex.split(" ")))
+
+
+    def shuffle_allbut_nouns_and_adj(self, ex):
+        doc = self.nlp(ex)
+        tokens = [token.text for token in doc]
+        text = np.array(tokens)
+        noun_adj_idx = [i for i, token in enumerate(doc) if token.tag_ in ['NN', 'NNS', 'NNP', 'NNPS', 'JJ', 'JJR', 'JJS']]
+        ## Finding adjectives
+
+        else_idx = np.ones(text.shape[0])
+        else_idx[noun_adj_idx] = 0
+
+        else_idx = else_idx.astype(bool)
+        ## Shuffle everything that are nouns or adjectives
+        text[else_idx] = np.random.permutation(text[else_idx])
+        return " ".join(text)
+
+
+    def get_trigrams(self, sentence):
+        # Taken from https://github.com/lingo-mit/context-ablations/blob/478fb18a9f9680321f0d37dc999ea444e9287cc0/code/transformers/src/transformers/data/data_augmentation.py
+        trigrams = []
+        trigram = []
+        for i in range(len(sentence)):
+            trigram.append(sentence[i])
+            if i % 3 == 2:
+                trigrams.append(trigram[:])
+                trigram = []
+        if trigram:
+            trigrams.append(trigram)
+        return trigrams
+
+    def trigram_shuffle(self, sentence):
+        trigrams = self.get_trigrams(sentence)
+        for trigram in trigrams:
+            random.shuffle(trigram)
+        return " ".join([" ".join(trigram) for trigram in trigrams])
+
+
+    def shuffle_within_trigrams(self, ex):
+        import nltk
+        tokens = nltk.word_tokenize(ex)
+        shuffled_ex = self.trigram_shuffle(tokens)
+        return shuffled_ex
+
+
+    def shuffle_trigrams(self, ex):
+        import nltk
+        tokens = nltk.word_tokenize(ex)
+        trigrams = self.get_trigrams(tokens)
+        random.shuffle(trigrams)
+        shuffled_ex = " ".join([" ".join(trigram) for trigram in trigrams])
+        return shuffled_ex
+
+
+def _handle_image_4shuffle(x):
+    return_image = False
+    if not isinstance(x, torch.Tensor):
+        # print(f"x is not a tensor: {type(x)}. Trying to handle but fix this or I'll annoy you with this log")
+        t = torch.tensor(np.array(x)).unsqueeze(dim=0).float()
+        t = t.permute(0, 3, 1, 2)
+        return_image = True
+        return t, return_image
+    if len(x.shape) != 4:
+        #print("You did not send a tensor of shape NxCxWxH. Unsqueezing not but fix this or I'll annoy you with this log")
+        return x.unsqueeze(dim=0), return_image
+    else:
+        # Good boi
+        return x, return_image
+        
+
+def shuffle_rows(x, n_rows=7):
+    """
+    Shuffle the rows of the image tensor where each row has a size of 14 pixels.
+    Tensor is of shape N x C x W x H
+    """
+    x, return_image = _handle_image_4shuffle(x)
+    patch_size = x.shape[-2]//n_rows
+    u = nnf.unfold(x, kernel_size=(patch_size, x.shape[-1]), stride=patch_size, padding=0)
+    # permute the patches of each image in the batch
+    pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
+    # fold the permuted patches back together
+    f = nnf.fold(pu, x.shape[-2:], kernel_size=(patch_size, x.shape[-1]), stride=patch_size, padding=0)
+    
+    image = f.squeeze() # C W H
+    if return_image:
+        return T.ToPILImage()(image.type(torch.uint8))
+    else:
+        return image
+
+
+def shuffle_columns(x, n_cols=7):
+    """
+    Shuffle the columns of the image tensor where we'll have n_cols columns.
+    Tensor is of shape N x C x W x H
+    """
+    x, return_image = _handle_image_4shuffle(x)
+    patch_size = x.shape[-1]//n_cols
+    u = nnf.unfold(x, kernel_size=(x.shape[-2], patch_size), stride=patch_size, padding=0)
+    # permute the patches of each image in the batch
+    pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
+    # fold the permuted patches back together
+    f = nnf.fold(pu, x.shape[-2:], kernel_size=(x.shape[-2], patch_size), stride=patch_size, padding=0)
+    image = f.squeeze() # C W H
+    if return_image:
+        return T.ToPILImage()(image.type(torch.uint8))
+    else:
+        return image
+
+
+
+def shuffle_patches(x, n_ratio=4):
+    """
+    Shuffle the rows of the image tensor where each row has a size of 14 pixels.
+    Tensor is of shape N x C x W x H
+    """
+    x, return_image = _handle_image_4shuffle(x)
+    patch_size_x = x.shape[-2]//n_ratio
+    patch_size_y = x.shape[-1]//n_ratio
+    u = nnf.unfold(x, kernel_size=(patch_size_x, patch_size_y), stride=(patch_size_x, patch_size_y), padding=0)
+    # permute the patches of each image in the batch
+    pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
+    # fold the permuted patches back together
+    f = nnf.fold(pu, x.shape[-2:], kernel_size=(patch_size_x, patch_size_y), stride=(patch_size_x, patch_size_y), padding=0)
+    image = f.squeeze() # C W H
+    if return_image:
+        return T.ToPILImage()(image.type(torch.uint8))
+    else:
+        return image
\ No newline at end of file
diff --git a/multimodal/open_flamingo/eval/dataset_zoo/retrieval.py b/multimodal/open_flamingo/eval/dataset_zoo/retrieval.py
new file mode 100644
index 0000000000000000000000000000000000000000..064a5924258a3f097e331af1d01fedb98eeae7be
--- /dev/null
+++ b/multimodal/open_flamingo/eval/dataset_zoo/retrieval.py
@@ -0,0 +1,266 @@
+import os
+import re
+import json
+import numpy as np
+
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from .constants import COCO_ROOT, FLICKR_ROOT
+from .utils import AverageMeter
+
+
+def pre_caption(caption,max_words=50):
+    caption = re.sub(
+        r"([.!\"()*#:;~])",       
+        ' ',
+        caption.lower(),
+    )
+    caption = re.sub(
+        r"\s{2,}",
+        ' ',
+        caption,
+    )
+    caption = caption.rstrip('\n') 
+    caption = caption.strip(' ')
+
+    #truncate caption
+    caption_words = caption.split(' ')
+    if len(caption_words)>max_words:
+        caption = ' '.join(caption_words[:max_words])
+    
+    return caption
+
+
+class COCO_Retrieval(Dataset):
+    def __init__(self, image_preprocess=None, root_dir=COCO_ROOT, max_words=30, split="test",
+                 image_perturb_fn=None, download=False):  
+        """
+        COCO Retrieval Dataset.
+        image_preprocess: image preprocessing function
+        root_dir: The directory of the coco dataset. This directory should contain test2014 files.
+        max_words: Cropping the caption to max_words.
+        split: 'val' or 'test'
+        image_perturb_fn: image perturbation function for patch permutation experiments.
+        download: Whether to download the dataset if it does not exist.
+        """
+        self.root_dir = root_dir
+        if not os.path.exists(root_dir):
+            print("Directory for COCO could not be found!")
+            if download:
+                print("Downloading COCO now.")
+                self.download()
+            else:
+                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")
+        
+        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+        download_url(urls[split],root_dir)
+        
+        
+        self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
+        self.image_preprocess = image_preprocess
+        self.image_perturb_fn = image_perturb_fn
+        self.image_root = root_dir
+        
+        self.text = []
+        self.image = []
+        self.txt2img = {}
+        self.img2txt = {}
+        
+        txt_id = 0
+        for img_id, ann in enumerate(self.annotation):
+            self.image.append(ann['image'])
+            self.img2txt[img_id] = []
+            for i, caption in enumerate(ann['caption']):
+                self.text.append(pre_caption(caption,max_words))
+                self.img2txt[img_id].append(txt_id)
+                self.txt2img[txt_id] = img_id
+                txt_id += 1
+                                    
+    def __len__(self):
+        return len(self.annotation)
+    
+    def __getitem__(self, index):    
+        image_path = os.path.join(self.image_root, self.annotation[index]['image'])        
+        image = Image.open(image_path).convert('RGB')    
+        
+        if self.image_preprocess is not None: 
+            image = self.image_preprocess(image)
+          
+        if self.image_perturb_fn is not None:
+            image = self.image_perturb_fn(image) 
+         
+        return {"image": image, "idx": index}
+    
+    def download(self):
+        import subprocess
+        os.makedirs(self.root_dir, exist_ok=True)
+        #subprocess.call(["wget", "http://images.cocodataset.org/zips/train2014.zip"], cwd=self.root_dir)
+        #subprocess.call(["unzip", "train2014.zip"], cwd=self.root_dir)
+        
+        subprocess.call(["wget", "http://images.cocodataset.org/zips/val2014.zip"], cwd=self.root_dir)
+        subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)
+        
+        subprocess.call(["wget", "http://images.cocodataset.org/zips/test2014.zip"], cwd=self.root_dir)
+        subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)
+        
+    
+    def evaluate_scores(self, scores):
+        if isinstance(scores, tuple):
+            scores_i2t = scores[0]
+            scores_t2i = scores[1].T # Make it N_ims x N_text
+    
+        else:
+            scores_t2i = scores
+            scores_i2t = scores
+
+        print(f"COCO results across {scores_i2t.shape} samples. ")
+        prec_at_1 = AverageMeter()
+        prec_at_5 = AverageMeter()
+
+        # Text retrieval
+        tqdm_iterator = tqdm(range(len(self.img2txt)))
+        for i in tqdm_iterator:
+            top5_captions = np.argsort(scores_i2t[i])[-5:]
+            true_captions = self.img2txt[i]
+
+            prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:]))>0)
+            prec_at_5.update(len(set(true_captions) & set(top5_captions))>0)
+
+            tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")
+
+        # Image Retrieval
+        image_prec_at_1 = AverageMeter()
+        image_prec_at_5 = AverageMeter()
+
+        tqdm_iterator = tqdm(range(len(self.txt2img)))
+        for i in tqdm_iterator:
+            top5_images = np.argsort(scores_t2i[:, i])[-5:]
+            true_image = self.txt2img[i]
+
+            image_prec_at_1.update(true_image in top5_images[-1:])
+            image_prec_at_5.update(true_image in top5_images)
+
+            tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")
+
+        records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
+        return records
+
+
+
+class Flickr30k_Retrieval(Dataset):
+    def __init__(self, image_preprocess, split, root_dir=FLICKR_ROOT, max_words=30,
+                 image_perturb_fn=None, *args, **kwargs):  
+        '''
+        Flickr30k dataset for retrieval.
+        image_preprocess: image preprocessing function
+        root_dir: The directory of the coco dataset. This directory should contain test2014 files.
+        max_words: Cropping the caption to max_words.
+        split: 'val' or 'test'
+        image_perturb_fn: image perturbation function for patch permutation experiments.
+        download: Whether to download the dataset if it does not exist.
+        '''
+        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
+                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
+        filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
+        
+        if not os.path.exists(root_dir):
+            print("Directory for Flickr30k could not be found!")
+            flickr_url = "https://forms.illinois.edu/sec/229675"
+            raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")
+        
+        download_url(urls[split],root_dir)
+        
+        self.annotation = json.load(open(os.path.join(root_dir,filenames[split]),'r'))
+        self.image_preprocess = image_preprocess
+        self.image_perturb_fn = image_perturb_fn
+        self.root_dir = root_dir
+        
+        self.text = []
+        self.image = []
+        self.txt2img = {}
+        self.img2txt = {}
+        
+        txt_id = 0
+        for img_id, ann in enumerate(self.annotation):
+            self.image.append(ann['image'])
+            self.img2txt[img_id] = []
+            for i, caption in enumerate(ann['caption']):
+                self.text.append(pre_caption(caption,max_words))
+                self.img2txt[img_id].append(txt_id)
+                self.txt2img[txt_id] = img_id
+                txt_id += 1
+                                    
+    def __len__(self):
+        return len(self.annotation)
+    
+    def __getitem__(self, index):    
+        image_path = os.path.join(self.root_dir, self.annotation[index]['image'])        
+        image = Image.open(image_path).convert('RGB')   
+        if self.image_preprocess is not None: 
+            image = self.image_preprocess(image)  
+        if self.image_perturb_fn is not None:
+            image = self.image_perturb_fn(image) 
+        
+        return {"image": image, "idx": index}
+    
+    def evaluate_scores(self, scores):
+        if isinstance(scores, tuple):
+            scores_i2t = scores[0]
+            scores_t2i = scores[1].T # Make it N_ims x N_text
+    
+        else:
+            scores_t2i = scores
+            scores_i2t = scores
+
+        print(f"Flickr30k Retrieval results across {scores_i2t.shape} samples. ")
+        prec_at_1 = AverageMeter()
+        prec_at_5 = AverageMeter()
+
+        # Text retrieval
+        tqdm_iterator = tqdm(range(len(self.img2txt)))
+        for i in tqdm_iterator:
+            top5_captions = np.argsort(scores_i2t[i])[-5:]
+            true_captions = self.img2txt[i]
+
+            prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:]))>0)
+            prec_at_5.update(len(set(true_captions) & set(top5_captions))>0)
+
+            tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")
+
+        # Image Retrieval
+        image_prec_at_1 = AverageMeter()
+        image_prec_at_5 = AverageMeter()
+
+        tqdm_iterator = tqdm(range(len(self.txt2img)))
+        for i in tqdm_iterator:
+            top5_images = np.argsort(scores_t2i[:, i])[-5:]
+            true_image = self.txt2img[i]
+
+            image_prec_at_1.update(true_image in top5_images[-1:])
+            image_prec_at_5.update(true_image in top5_images)
+
+            tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")
+
+        records = [{"ImagePrec@1": image_prec_at_1.avg, "ImagePrec@5": image_prec_at_5.avg, "TextPrec@1": prec_at_1.avg, "TextPrec@5": prec_at_5.avg}]
+        return records
+    
+    def download(self):
+        raise NotImplementedError("Flickr30k dataset is not available for download.")
+
+
+
+def get_coco_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=COCO_ROOT, split="test"):
+    dataset = COCO_Retrieval(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words, 
+                            download=download)
+    return dataset
+
+
+def get_flickr30k_retrieval(image_preprocess, image_perturb_fn, text_perturb_fn, max_words=30, download=False, root_dir=FLICKR_ROOT, split="test"):
+    dataset = Flickr30k_Retrieval(root_dir=root_dir, split=split, image_preprocess=image_preprocess, image_perturb_fn=image_perturb_fn, max_words=max_words, 
+                            download=download)
+    return dataset
diff --git a/multimodal/open_flamingo/eval/dataset_zoo/utils.py b/multimodal/open_flamingo/eval/dataset_zoo/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e2ef979e68cc50959a6681b028c40005f79f724
--- /dev/null
+++ b/multimodal/open_flamingo/eval/dataset_zoo/utils.py
@@ -0,0 +1,15 @@
+class AverageMeter(object):
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
diff --git a/multimodal/open_flamingo/eval/eval_datasets.py b/multimodal/open_flamingo/eval/eval_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..672cf9e0c94935d0d4574f689e499a0b51b777b5
--- /dev/null
+++ b/multimodal/open_flamingo/eval/eval_datasets.py
@@ -0,0 +1,101 @@
+import json
+import os
+
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision.datasets import ImageFolder
+
+from open_flamingo.eval.imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
+
+
+class COCOFlickrDataset(Dataset):
+    def __init__(
+        self,
+        image_dir_path,
+        annotations_path,
+        is_flickr=False,
+    ):
+        self.image_dir_path = image_dir_path
+        self.annotations = json.load(open(annotations_path))["annotations"]
+        self.is_flickr = is_flickr
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def get_img_path(self, idx):
+        if self.is_flickr:
+            return f"{self.image_dir_path}/{self.annotations[idx]['image_id']}.jpg"
+        else:
+            return f"{self.image_dir_path}/{self.annotations[idx]['image_id']:012d}.jpg"
+
+    def __getitem__(self, idx):
+        image = Image.open(self.get_img_path(idx))
+        caption = self.annotations[idx]["caption"]
+        return {
+            "image": image,
+            "caption": caption,
+            "image_id": self.annotations[idx]["image_id"],
+        }
+
+
+class VQADataset(Dataset):
+    def __init__(
+        self,
+        image_dir_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/train2014/",
+        question_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
+        annotations_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_mscoco_train2014_annotations.json",
+        vqa_dataset="vqa",
+    ):
+        self.questions = json.load(open(question_path, "r"))["questions"]
+        self.answers = json.load(open(annotations_path, "r"))["annotations"]
+        self.image_dir_path = image_dir_path
+        self.vqa_dataset = vqa_dataset
+
+    def __len__(self):
+        return len(self.questions)
+
+    def get_img_path(self, question):
+        if self.vqa_dataset == "vqa":
+            return os.path.join(
+                self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
+            )
+        elif self.vqa_dataset == "ok_vqa":
+            return os.path.join(
+                self.image_dir_path, f"COCO_val2014_{question['image_id']:012d}.jpg"
+            )
+        else:
+            raise Exception(f"Unknown VQA dataset {self.vqa_dataset}")
+
+    def __getitem__(self, idx):
+        question = self.questions[idx]
+        answers = self.answers[idx]
+        img_path = self.get_img_path(question)
+        image = Image.open(img_path)
+        return {
+            "image": image,
+            "question": question["question"],
+            "answers": [a["answer"] for a in answers["answers"]],
+            "question_id": question["question_id"],
+        }
+
+
+class ImageNetDataset(ImageFolder):
+    """Class to represent the ImageNet1k dataset."""
+
+    def __init__(self, root, **kwargs):
+        super().__init__(root=root, **kwargs)
+
+    def __getitem__(self, idx):
+        sample, target = super().__getitem__(idx)
+        target_label = IMAGENET_1K_CLASS_ID_TO_LABEL[target]
+        return {
+            "image": sample,
+            "class_id": target,  # numeric ID of the ImageNet class
+            "class_name": target_label,  # human-readable name of ImageNet class
+        }
+
+
+if __name__ == "__main__":
+    gqa_dataset = GQADataset()
+    for sample in gqa_dataset:
+        print(sample)
diff --git a/multimodal/open_flamingo/eval/evaluate.py b/multimodal/open_flamingo/eval/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f3586649e669f423edeb803ac5eba5df283a9d
--- /dev/null
+++ b/multimodal/open_flamingo/eval/evaluate.py
@@ -0,0 +1,1435 @@
+import argparse
+import json
+from math import ceil
+import os
+import random
+import uuid
+from collections import defaultdict
+from typing import Callable
+import time
+import cv2
+import webdataset as wds
+from sklearn.metrics import recall_score, average_precision_score
+
+import more_itertools
+import numpy as np
+import torch
+from coco_metric import compute_cider, postprocess_captioning_generation
+from eval_datasets import VQADataset
+from tqdm import tqdm
+from collections import Counter
+
+from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
+from open_flamingo.eval.classification import (
+    compute_per_sample_probs,
+    compute_per_sample_loss,
+)
+from open_flamingo.eval.imagenet_utils import (
+    openai_imagenet_classnames,
+    IMAGENET_1K_CLASS_ID_TO_LABEL,
+)
+
+from open_flamingo.src.factory import create_model_and_transforms
+from PIL import Image
+from io import BytesIO
+import base64
+from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
+import string
+from open_flamingo.eval.task.reg import evaluate_reg
+from open_flamingo.eval.task.gqa import GQADataset
+from open_flamingo.eval.task.vl_checklist import evaluate_vlc
+from open_flamingo.eval.task.crepe import evaluate_crepe
+from open_flamingo.eval.task.caption import evaluate_coco_flickr
+from open_flamingo.eval.task.utils import is_correct, get_iou
+from open_flamingo.eval.task.cola import evaluate_cola
+from open_flamingo.eval.task.gqa import evaluate_gqa
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
+parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
+parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
+parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
+parser.add_argument("--checkpoint_path", type=str, required=True)
+parser.add_argument(
+    "--results_file", type=str, default=None, help="JSON file to save results"
+)
+
+# Trial arguments
+parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
+parser.add_argument(
+    "--num_trials",
+    type=int,
+    default=1,
+    help="Number of trials to run for each shot using different demonstrations",
+)
+parser.add_argument(
+    "--trial_seeds",
+    nargs="+",
+    default=[0],
+    help="Seeds to use for each trial for picking demonstrations and eval sets",
+)
+parser.add_argument(
+    "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
+)
+
+parser.add_argument("--batch_size", type=int, default=8)
+
+# Per-dataset evaluation flags
+parser.add_argument(
+    "--eval_coco",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on COCO.",
+)
+parser.add_argument(
+    "--eval_vqav2",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on VQAV2.",
+)
+parser.add_argument(
+    "--eval_ok_vqa",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on OK-VQA.",
+)
+parser.add_argument(
+    "--eval_imagenet",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on ImageNet.",
+)
+
+parser.add_argument(
+    "--eval_flickr30",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on Flickr30.",
+)
+
+parser.add_argument(
+    "--eval_refcoco",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on RefCOCO.",
+)
+
+# Dataset arguments
+
+## Flickr30 Dataset
+parser.add_argument(
+    "--flickr_image_dir_path",
+    type=str,
+    help="Path to the flickr30/flickr30k_images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--flickr_annotations_json_path",
+    type=str,
+    help="Path to the dataset_flickr30k_coco_style.json file.",
+    default=None,
+)
+
+## COCO Dataset
+parser.add_argument(
+    "--coco_image_dir_path",
+    type=str,
+    help="Path to the flickr30/flickr30k_images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--coco_annotations_json_path",
+    type=str,
+    default=None,
+)
+
+## VQAV2 Dataset
+parser.add_argument(
+    "--vqav2_image_dir_path",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--vqav2_questions_json_path",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--vqav2_annotations_json_path",
+    type=str,
+    default=None,
+)
+
+## OK-VQA Dataset
+parser.add_argument(
+    "--ok_vqa_image_dir_path",
+    type=str,
+    help="Path to the vqav2/train2014 directory.",
+    default=None,
+)
+parser.add_argument(
+    "--ok_vqa_questions_json_path",
+    type=str,
+    help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
+    default=None,
+)
+parser.add_argument(
+    "--ok_vqa_annotations_json_path",
+    type=str,
+    help="Path to the v2_mscoco_train2014_annotations.json file.",
+    default=None,
+)
+
+## Imagenet dataset
+parser.add_argument("--imagenet_root", type=str, default="/tmp")
+
+## RefCOCO dataset
+parser.add_argument("--refcoco_tsvfile", type=str, default=None)
+
+parser.add_argument(
+    "--location_token_num",
+    default=1000,
+    type=int,
+)
+# distributed training
+parser.add_argument(
+    "--dist-url",
+    default="env://",
+    type=str,
+    help="url used to set up distributed training",
+)
+parser.add_argument(
+    "--dist-backend", default="nccl", type=str, help="distributed backend"
+)
+parser.add_argument(
+    "--horovod",
+    default=False,
+    action="store_true",
+    help="Use horovod for distributed training.",
+)
+parser.add_argument(
+    "--no-set-device-rank",
+    default=False,
+    action="store_true",
+    help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+)
+parser.add_argument(
+    "--dist",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--lora",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--lora_r",
+    default=16,
+    type=int,
+    required=False,
+)
+parser.add_argument(
+    "--legacy",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--special",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--id",
+    default=0,
+    type=int,
+    required=False,
+)
+
+parser.add_argument(
+    "--eval_gqa",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--use_sam",
+    default=None,
+    type=str,
+    required=False,
+)
+parser.add_argument(
+    "--add_visual_token",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--use_format_v2",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_aro",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_pisc",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_reg",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_vlc",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_crepe",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_cola",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--level",
+    default=4,
+    type=int,
+)
+parser.add_argument(
+    "--type",
+    default="swap",
+    type=str,
+)
+parser.add_argument(
+    "--choose_left_right",
+    default=False,
+    action="store_true",
+)
+
+
+class OKVQAPostProcess():
+    def __init__(self):
+        self._lemmatizer = None
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+        
+
+def main():
+    args = parser.parse_args()
+    if args.dist:
+        args.local_rank, args.rank, args.world_size = world_info_from_env()
+        print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
+        device_id = init_distributed_device(args)
+    else:
+        args.rank = 0
+        args.world_size = 1
+        print(f"rank: {args.rank} world_size: {args.world_size}")
+    
+    if "sam" in args.checkpoint_path:
+        args.use_sam = "vit_l"
+
+    args.add_visual_token = True
+    if "lora" in args.checkpoint_path:
+        args.lora = True
+
+
+    args.add_pe = False
+    args.add_box = True
+    args.relation = False
+    args.enhance_data = False
+    args.use_format_v2 = True
+
+
+
+    import hashlib
+    args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
+
+    # load model
+    flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
+        args.vision_encoder_path,
+        args.vision_encoder_pretrained,
+        args.lm_path,
+        args.lm_tokenizer_path,
+        location_token_num=args.location_token_num,
+        lora=args.lora,
+        lora_r=16,
+        use_sam=args.use_sam,
+        add_visual_token=args.add_visual_token,
+        use_format_v2=args.use_format_v2,
+        add_box=args.add_box,
+        add_pe=args.add_pe,
+        add_relation=args.relation,
+        enhance_data=args.enhance_data,
+    )
+    flamingo.use_format_v2 = args.use_format_v2
+    if args.special:
+        flamingo.special = True
+    else:
+        flamingo.special = False
+    if args.legacy:
+        flamingo.legacy = True
+        print("use legacy evaluation")
+    flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
+    flamingo.expr_name = args.checkpoint_path.split("/")[-2]
+    if args.rank == 0:
+        print("legacy", True if hasattr(flamingo, "legacy") else False)
+        print("step:", flamingo.step_num)
+        print("expr:", flamingo.expr_name)
+        print("use format v2:", flamingo.use_format_v2)
+        print(args)
+    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+    model_state_dict = {}
+    for key in checkpoint["model_state_dict"].keys():
+        model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
+    if "vision_encoder.logit_scale"in model_state_dict:
+        # previous checkpoint has some unnecessary weights
+        del model_state_dict["vision_encoder.logit_scale"]
+        del model_state_dict["vision_encoder.visual.proj"]
+        del model_state_dict["vision_encoder.visual.ln_post.weight"]
+        del model_state_dict["vision_encoder.visual.ln_post.bias"]
+    flamingo.load_state_dict(model_state_dict, strict=True)
+    results = defaultdict(list)
+    if args.eval_coco:
+        print("Evaluating on COCO...")
+        cider_score = evaluate_coco_flickr(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["coco"].append({"score": cider_score})
+
+    if args.eval_ok_vqa:
+        print("Evaluating on OK-VQA...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                ok_vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.ok_vqa_image_dir_path,
+                    questions_json_path=args.ok_vqa_questions_json_path,
+                    annotations_json_path=args.ok_vqa_annotations_json_path,
+                    vqa_dataset="ok_vqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["ok_vqa"].append(
+                {"shots": shot, "score": ok_vqa_score}
+            )
+
+    if args.eval_vqav2:
+        print("Evaluating on VQAv2...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.vqav2_image_dir_path,
+                    questions_json_path=args.vqav2_questions_json_path,
+                    annotations_json_path=args.vqav2_annotations_json_path,
+                    vqa_dataset="vqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["vqav2"].append(
+                {"shots": shot, "score": vqa_score}
+            )
+
+    if args.eval_gqa:
+        print("Evaluating on GQA...")
+        gqa_score = evaluate_gqa(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["gqa"].append(
+            {"score": gqa_score}
+        )
+
+    if args.eval_refcoco:
+        print("Evaluating on RefCOCO...")
+        refcoco_score = evaluate_refcoco(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["refcoco"].append(
+            {"score": refcoco_score}
+        )
+    if args.eval_aro:
+        print("Evaluating on ARO...")
+        aro_score = evaluate_aro(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+            choose_left_right=args.choose_left_right,
+        )
+        results["aro"].append(
+            {"score": aro_score}
+        )
+    if args.eval_pisc:
+        print("Evaluating on ARO...")
+        aro_score = evaluate_pisc(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["pisc"].append(
+            {"score": aro_score}
+        )
+    if args.eval_reg:
+        print("Evaluating on Referring Expression Generation...")
+        cider = evaluate_reg(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["reg"].append(
+            {"score": cider}
+        )
+    if args.eval_vlc:
+        print("Evaluating on VL-checklist...")
+        vlc_score = evaluate_vlc(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["vlc"].append(
+            {"score": vlc_score}
+        )
+    if args.eval_crepe:
+        print("Evaluating on CREPE...")
+        crepe_score = evaluate_crepe(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+            level=args.level,
+            type=args.type,
+        )
+        results["crepe"].append(
+            {"score": crepe_score}
+        )
+    if args.eval_cola:
+        print("Evaluating on COLA...")
+        cola_score = evaluate_cola(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["cola"].append(
+            {"score": cola_score}
+        )
+
+def prepare_batch_images(batch, image_processor):
+    batch_images = None
+    for b in batch:
+        b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        if batch_images is None:
+            batch_images = b_image
+        else:
+            batch_images = torch.cat([batch_images, b_image], dim=0)
+    return batch_images
+
+def get_outputs(
+    model,
+    batch_images,
+    attention_mask,
+    max_generation_length,
+    min_generation_length,
+    num_beams,
+    length_penalty,
+    input_ids,
+    image_start_index_list=None,
+    image_nums=None,
+    bad_words_ids=None,
+):
+    with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+        outputs = model.generate(
+            batch_images,
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_generation_length,
+            min_length=min_generation_length,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+            bad_words_ids=bad_words_ids,
+        )
+
+    outputs = outputs[:, len(input_ids[0]) :]
+    return outputs
+
+
+def evaluate_vqa(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    image_dir_path=None,
+    questions_json_path=None,
+    annotations_json_path=None,
+    vqa_dataset="vqa",
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    """
+    Evaluate a model on VQA datasets. Currently supports VQA v2.0.
+
+    Args:
+        model (nn.Module): model to evaluate
+        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
+        image_processor : image processor for the model
+        batch_size (int): batch size
+        image_dir_path (str): path to image directory
+        questions_json_path (str): path to questions json file
+        annotations_json_path (str): path to annotations json file
+        seed (int, optional): random seed. Defaults to 42.
+        max_generation_length (int, optional): max generation length. Defaults to 5.
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
+        query_set_size (int, optional): size of the query set. Defaults to 2048.
+        num_shots (int, optional): number of shots to use. Defaults to 8.
+        device (int, optional): device to use. Defaults to -1 (cpu).
+        num_workers (int, optional): number of workers to use. Defaults to 4.
+        vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
+    Returns:
+        float: accuracy score
+    """
+    if world_size > 1:
+        torch.distributed.barrier()
+    if vqa_dataset == "gqa":
+        eval_dataset = GQADataset()
+    else:
+        eval_dataset = VQADataset(
+            image_dir_path=image_dir_path,
+            question_path=questions_json_path,
+            annotations_path=annotations_json_path,
+            vqa_dataset=vqa_dataset,
+        )
+    postprocessor = OKVQAPostProcess()
+    try:
+        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    except:
+        pass
+    def get_prompt(sample):
+        return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
+        # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+
+    model.eval().cuda()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    if "peft" in lang_encoder_name:
+        lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
+    predictions = []
+    tokenizer.padding_side = "left"
+    if world_size > 1:
+        torch.distributed.barrier()
+    this_tot = 0
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
+    )):
+        if ii % world_size != rank:
+            continue
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        batch_text = [get_prompt(s) for s in batch]
+        encodings = tokenizer(
+            batch_text,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"].cuda()
+        attention_mask = encodings["attention_mask"].cuda()
+        skip_special_tokens = True
+        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
+            if rank == 0:
+                tqdm.write("use legacy model")
+            for i in range(len(input_ids)):
+                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
+                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
+                input_ids[i, media_token_index - 1] = media_token_id
+                input_ids[i, media_token_index] = pad_token_id
+                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
+                input_ids[i, endofmedia_token_index] = bos_token_id
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        if "llama" in lang_encoder_name:
+            attention_mask[input_ids == 0] = 0
+        outputs = get_outputs(
+            model=model,
+            batch_images=batch_images,
+            attention_mask=attention_mask,
+            max_generation_length=10,
+            min_generation_length=1,
+            num_beams=5,
+            length_penalty=0,
+            input_ids=input_ids,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+        )
+        # postprocess begin
+        new_predictions = [
+            out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
+        ]
+        if vqa_dataset == "ok_vqa":
+            new_predictions = postprocessor._lemmatize(new_predictions)
+        if model.special:
+            for i in range(len(new_predictions)):
+                for answer, _ in Counter(batch[i]['answers']).most_common():
+                    if answer in new_predictions[i]:
+                        new_predictions[i] = answer
+                        break
+                    if "cant" in new_predictions[i] and "no" == answer:
+                        new_predictions[i] = answer
+                        break
+                    if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
+                        new_predictions[i] = answer
+                        break
+        
+        this_tot += 1
+        if rank == 0 and this_tot % 20 == 0:
+            for i in range(1):
+                tqdm.write("model output: " + new_predictions[i])
+
+        predictions.extend(
+            [
+                {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
+                for p, sample in zip(new_predictions, batch)
+            ]
+        )
+    with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps(predictions))
+    print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
+
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+            predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
+            os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+        print("num:", len(predictions))
+        # save the predictions to a temporary file
+        random_uuid = str(uuid.uuid4())
+        with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
+            f.write(json.dumps(predictions, indent=4))
+
+        if vqa_dataset == "gqa":
+            acc = compute_gqa_accuracy(predictions)
+        else:
+            acc = compute_vqa_accuracy(
+                f"{vqa_dataset}results_{random_uuid}.json",
+                questions_json_path,
+                annotations_json_path,
+                vqa_dataset=vqa_dataset,
+            )
+        print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
+        os.makedirs("eval_results", exist_ok=True)
+        with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(f"{vqa_dataset}results_{random_uuid}.json")
+    else:
+        time.sleep(5)
+        acc = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return acc
+
+
+def evaluate_refcoco(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    model.eval().cuda()
+    loc_token_ids = []
+    for i in range(1000):
+        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
+    # all_ids = set(range(model.lang_encoder.lm_head.out_features))
+    # bad_words_ids = list(all_ids - set(loc_token_ids))
+    # bad_words_ids = [[b] for b in bad_words_ids]
+    # min_loc_token_id = min(loc_token_ids)
+    # max_loc_token_id = max(loc_token_ids)
+    total = 0
+    correct = 0
+    ious = []
+    if "refcocog" in tsvfile:
+        dataset_name = "refcocog"
+    elif "refcocoplus" in tsvfile:
+        dataset_name = "refcocoplus"
+    else:
+        dataset_name = "refcoco"
+    with open(tsvfile, "r") as f:
+        lines = f.readlines()
+        pbar = tqdm(lines, disable=(rank != 0))
+        for ii, line in enumerate(pbar):
+            if ii % world_size != rank:
+                continue
+            total += 1
+            line = line.rstrip()
+            uniq_id, image_id, text, region_coord, image = line.split("\t")
+
+            image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
+
+            gt_box = np.array(list(map(float, region_coord.split(","))))
+            width = image.width
+            height = image.height
+            image = image.resize((224, 224))
+            gt_box = gt_box / np.array([width, height, width, height]) * 224
+            batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+            text = text.rstrip('.').strip().replace('"', '').capitalize()
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text}<|#endofobject#|><|#visual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
+
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            # attention_mask[input_ids == prebox_token_id] = 0
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+
+            model.debug_id = 0
+            with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=None,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=None,
+                    add_box=False,
+                )
+            boxes = outputs["boxes"]
+            scores = outputs["scores"]
+            boxes = boxes[scores >= scores[0]*0.5]
+            scores = scores[scores >= scores[0]*0.5]
+
+            text = text.lower().strip()
+            if text.split(" ")[0] not in ["a", "an", "the", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "several", "some"]:
+                text = "a " + text
+            losses = []
+            for box, score in zip(boxes, scores):
+                this_prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>There is<|#object#|><|#previsual#|><|#prebox#|><|#object#|> {text}"]
+                encodings = tokenizer(
+                    this_prompt,
+                    padding="longest",
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=2000,
+                )
+                input_ids = encodings["input_ids"]
+                attention_mask = encodings["attention_mask"]
+                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+                image_start_index_list = [[x] for x in image_start_index_list]
+                image_nums = [1] * len(input_ids)
+                vision_x = batch_images.cuda()
+                lang_x = input_ids.cuda()
+                attention_mask = attention_mask.cuda()
+                added_bbox_list = [torch.tensor(box / 224).cuda().unsqueeze(0).clamp(0, 0.99)]
+                labels = lang_x.clone()
+                start_idx = (lang_x == object_token_id).nonzero()[-1, -1]
+                labels[0, :start_idx+1] = -100
+                with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+                    outputs = model(
+                        vision_x=vision_x,
+                        lang_x=lang_x,
+                        attention_mask=attention_mask,
+                        labels=labels,
+                        image_nums=image_nums,
+                        image_start_index_list=image_start_index_list,
+                        added_bbox_list=added_bbox_list,
+                        add_box=True,
+                    )
+                    # print(tokenizer.decode(outputs.logits[0, start_idx].sort(descending=True).indices[:10]))
+                    loss = outputs.loss.detach().cpu()
+                    losses.append((loss.sum() / (loss != 0).sum()).item())
+            chosen_idx = np.array(losses).argmin()
+            pred_box = boxes[chosen_idx]
+            if chosen_idx != 0:
+                tqdm.write(f"{text}|{chosen_idx}|{scores[chosen_idx]}")
+            iou = get_iou(pred_box, gt_box)
+            if iou >= 0.5:
+                correct += 1
+            # else:
+            #     if rank == 0:
+            #         tqdm.write(text.rstrip('.').strip().lower())
+            #     open_cv_image = np.array(image)
+            #     # Convert RGB to BGR
+            #     open_cv_image = open_cv_image[:, :, ::-1].copy()
+            #     open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+            #     open_cv_image = cv2.rectangle(open_cv_image, gt_box[:2].astype(int), gt_box[2:].astype(int), (0, 255, 0), 2)
+            #     cv2.imwrite(f"refcocog_result/{ii}_{iou}_{text}.jpg", open_cv_image)
+            pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
+            # open_cv_image = np.array(image)
+            # # Convert RGB to BGR 
+            # open_cv_image = open_cv_image[:, :, ::-1].copy() 
+            # for box, score in zip(boxes, scores):
+            #     open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+            # cv2.imwrite("output.jpg", open_cv_image)
+            # print(boxes)
+            # print(scores)
+            # exit()
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+
+# def preprocess_visual_info(Text):
+#     text = Text.split(" ")
+#     for is_idx, t in enumerate(text):
+#         if t == "is":
+#             break
+#     the_idx = is_idx
+#     while text[the_idx] != "the":
+#         the_idx -= 1
+#     obj_A = " ".join(text[the_idx+1:is_idx])
+#     second_the_idx = len(text) - 1
+#     while text[second_the_idx] != "the":
+#         second_the_idx -= 1
+#     obj_B =  " ".join(text[second_the_idx+1:])
+#     visual_obj_A = f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
+#     visual_obj_B = f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
+#     Text = Text.replace(obj_A, f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
+#     Text = Text.replace(obj_B, f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
+#     return Text, obj_A, obj_B, visual_obj_A, visual_obj_B
+
+
+def preprocess_visual_info(Text):
+    text = Text.split(" ")
+    for is_idx, t in enumerate(text):
+        if t == "is":
+            break
+    the_idx = is_idx
+    while text[the_idx] != "the":
+        the_idx -= 1
+    obj_A = " ".join(text[the_idx+1:is_idx])
+    second_the_idx = len(text) - 1
+    while text[second_the_idx] != "the":
+        second_the_idx -= 1
+    obj_B = " ".join(text[second_the_idx+1:])
+    relation = " ".join(text[is_idx+1:second_the_idx])
+    visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
+    visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
+    Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
+    return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
+
+
+
+
+def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
+    assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
+    encodings = tokenizer(
+        prompt,
+        padding="longest",
+        truncation=True,
+        return_tensors="pt",
+        max_length=2000,
+    )
+    input_ids = encodings["input_ids"]
+    attention_mask = encodings["attention_mask"]
+    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+    image_start_index_list = [[x] for x in image_start_index_list]
+    image_nums = [1] * len(input_ids)
+    vision_x = batch_images.cuda()
+    lang_x = input_ids.cuda()
+    attention_mask = attention_mask.cuda()
+
+    model.debug_id = 0
+    with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+        outputs = model(
+            vision_x=vision_x,
+            lang_x=lang_x,
+            attention_mask=attention_mask,
+            labels=None,
+            image_nums=image_nums,
+            image_start_index_list=image_start_index_list,
+            added_bbox_list=visual_box_list,
+            add_box=visual_box_list is not None,
+            relations=None,
+            debug_mode=False,
+        )
+    boxes = outputs["boxes"]
+    scores = outputs["scores"]
+    if debug:
+        import pdb; pdb.set_trace()
+    if return_all:
+        return boxes, scores
+    if len(scores) == 0:
+        return None, None
+    else:
+        return boxes[scores.argmax()], scores.max()
+
+
+def evaluate_aro(
+    model,
+    tokenizer,
+    image_processor,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    add_visual=True,
+    subset=False,
+    choose_left_right=False,
+):
+    # os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
+    dataset_name = "aro"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    model.eval().cuda()
+    total = 0
+    n_top1 = 0
+    n_top5 = 0
+    from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
+    vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
+    if subset:
+        subset_idx = json.load(open("aro_subset.json"))
+        pbar = tqdm(subset_idx, disable=(rank != 0))
+    else:
+        pbar = tqdm(vgr_dataset, disable=(rank != 0))
+    for ii, sample in enumerate(pbar):
+        if subset:
+            ORI_IDX = int(sample)
+            sample = vgr_dataset[sample]
+        if ii % world_size != rank:
+            continue
+        image = sample["image_options"][0]
+        # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
+        image = image.resize((224, 224))
+
+        text = sample["caption_options"][1] # 1 is true caption
+        # text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
+        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
+
+
+        first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
+        first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
+
+        if first_box is None:
+            text_A = "the " + obj_A
+            added_bbox_list = None
+        else:
+            text_A = visual_obj_A
+            added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
+
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
+        pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, 
+        prebox_token_id, return_all=True)
+
+        if pre_boxes is None:
+            pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
+            pre_scores = [1.0]
+
+        logits_list = []
+        # pre_boxes = [pre_boxes[0]]
+        # pre_scores = [pre_scores[0]]
+        for pre_box, pre_score in zip(pre_boxes, pre_scores):
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
+
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=512,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+            labels = lang_x.clone()
+            added_bbox_list = None
+            if add_visual:
+                added_bbox_list = []
+                if first_box is not None:
+                    added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
+                if pre_box is not None:
+                    added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
+            if added_bbox_list is not None and len(added_bbox_list) == 0:
+                added_bbox_list = None
+
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox_list,
+                    add_box=added_bbox_list is not None,
+                    relations=None,
+                )
+            logits_list.append([pre_score, outputs.logits])
+        pre_scores = np.array([x[0] for x in logits_list])
+        final_probs = 0.0
+        for score, (_, logits) in zip(pre_scores, logits_list):
+            final_probs += score * logits.softmax(-1)
+        assert input_ids.shape[:2] == final_probs.shape[:2]
+        _rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, obj_B, topk=5)
+        if is_top1:
+            n_top1 += 1
+        if is_top5:
+            n_top5 += 1
+        total += 1
+        pbar.set_description(f"acc@top1: {n_top1 / total:.4f} | acc@top5: {n_top5 / total:.4f} | {_rank}")
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, n_top1, n_top5]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        n_top1 = 0
+        n_top5 = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, n_top1_part, n_top5_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            n_top1 += n_top1_part
+            n_top5 += n_top5_part
+        acc_top1 = n_top1 / total
+        acc_top5 = n_top5 / total
+        print("acc_top1:", acc_top1, "acc_top5:", acc_top5, "total:", total)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc_top1}_{acc_top5}_{total}_{subset}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+def evaluate_pisc(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    add_visual=True,
+):
+    from open_flamingo.train.instruction_template import PISC_TEMPLATES
+    dataset_name = "pisc"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    model.train().cuda()
+
+    dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
+    pbar = tqdm(dataset, disable=(rank != 0))
+
+    rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
+    rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
+    gt = []
+    pred_scores = []
+    for III, sample in enumerate(pbar):
+        if III % world_size != rank:
+            continue
+        image_path, dataset, data = sample
+        image = Image.open(image_path)
+        size = image_processor.transforms[0].size
+        image = image.resize((size, size))
+        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        boxA = data[0]
+        boxB = data[1]
+        gt_relation = data[2]
+        losses = []
+        for i_rel, option_rel in enumerate(rel_id_to_type):
+            text = PISC_TEMPLATES[0].format(relation=option_rel)
+            added_bbox = [
+                torch.tensor([boxA]).cuda(),
+                torch.tensor([boxB]).cuda(),
+            ]
+            caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
+            encodings = tokenizer(
+                caption,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+
+            labels = lang_x.clone()
+            labels[labels == tokenizer.pad_token_id] = -100
+            if add_visual:
+                # endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
+                # endofattr_next_token_index[1] += 1
+                # endofattr_next_token_id = labels[endofattr_next_token_index]
+                # </obj><visual><box></attr>NEXT_WORD
+                # </obj> predict NEXT_WORD
+                # <visual><box></attr> predict nothing
+                labels[labels == visual_token_id] = -100
+                labels[labels == box_token_id] = -100
+                labels[labels == endofattr_token_id] = -100
+                # labels[endofattr_next_token_index] = -100
+            labels[:, 0] = -100
+            answer_token_id = tokenizer(" Answer").input_ids[0]
+            answer_token_loc = (input_ids == answer_token_id).nonzero()
+            for batch_idx, idx in answer_token_loc:
+                labels[batch_idx][:idx+2] = -100
+
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox,
+                    add_box=added_bbox is not None,
+                )
+                loss_total = outputs.loss.reshape(labels.shape[0], -1)
+                loss = loss_total.sum() / (loss_total != 0).sum()
+                losses.append(loss.item())
+        pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
+        gt.append(rel_type_to_id[gt_relation])
+    gt = np.array(gt)
+    pred_scores = np.array(pred_scores)
+    pred = pred_scores.argmax(1)
+
+
+    print("total num:", len(gt))
+    recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
+    print("recalls:", recalls)
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([gt.tolist(), pred.tolist()]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        gt = []
+        pred = []
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            gt.extend(gt_part)
+            pred.extend(pred_part)
+        print("total num:", len(gt))
+        recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
+        print("recalls:", recalls)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
+            f.write(f"{gt}\n")
+            f.write(f"{pred}\n")
+            f.write(f"{recalls}\n")
+    score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/open_flamingo/eval/evaluate_debug.py b/multimodal/open_flamingo/eval/evaluate_debug.py
new file mode 100644
index 0000000000000000000000000000000000000000..989f280df613db0120ae7e73f0d57f3b785de653
--- /dev/null
+++ b/multimodal/open_flamingo/eval/evaluate_debug.py
@@ -0,0 +1,1159 @@
+import argparse
+import json
+from math import ceil
+import os
+import random
+import uuid
+from collections import defaultdict
+from typing import Callable
+import time
+import cv2
+
+import more_itertools
+import numpy as np
+import torch
+from coco_metric import compute_cider, postprocess_captioning_generation
+from eval_datasets import VQADataset, GQADataset
+from tqdm import tqdm
+from collections import Counter
+
+from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
+from open_flamingo.eval.classification import (
+    compute_per_sample_probs,
+    compute_per_sample_loss,
+)
+from open_flamingo.eval.imagenet_utils import (
+    openai_imagenet_classnames,
+    IMAGENET_1K_CLASS_ID_TO_LABEL,
+)
+
+from open_flamingo.src.factory import create_model_and_transforms
+from PIL import Image
+from io import BytesIO
+import base64
+from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
+import string
+from lavis.datasets.builders import load_dataset
+
+
+def get_iou(box1, box2):
+    # box1 and box2 should be in the format [x1, y1, x2, y2]
+    intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
+                   max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
+    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    union = area_box1 + area_box2 - intersection
+    iou = intersection / union if union > 0 else 0
+    return iou
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
+parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
+parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
+parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
+parser.add_argument("--checkpoint_path", type=str, required=True)
+parser.add_argument(
+    "--results_file", type=str, default=None, help="JSON file to save results"
+)
+
+# Trial arguments
+parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
+parser.add_argument(
+    "--num_trials",
+    type=int,
+    default=1,
+    help="Number of trials to run for each shot using different demonstrations",
+)
+parser.add_argument(
+    "--trial_seeds",
+    nargs="+",
+    default=[0],
+    help="Seeds to use for each trial for picking demonstrations and eval sets",
+)
+parser.add_argument(
+    "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
+)
+
+parser.add_argument("--batch_size", type=int, default=8)
+
+# Per-dataset evaluation flags
+parser.add_argument(
+    "--eval_coco",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on COCO.",
+)
+parser.add_argument(
+    "--eval_vqav2",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on VQAV2.",
+)
+parser.add_argument(
+    "--eval_ok_vqa",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on OK-VQA.",
+)
+parser.add_argument(
+    "--eval_imagenet",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on ImageNet.",
+)
+
+parser.add_argument(
+    "--eval_flickr30",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on Flickr30.",
+)
+
+parser.add_argument(
+    "--eval_refcoco",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on RefCOCO.",
+)
+
+# Dataset arguments
+
+## Flickr30 Dataset
+parser.add_argument(
+    "--flickr_image_dir_path",
+    type=str,
+    help="Path to the flickr30/flickr30k_images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--flickr_annotations_json_path",
+    type=str,
+    help="Path to the dataset_flickr30k_coco_style.json file.",
+    default=None,
+)
+
+## COCO Dataset
+parser.add_argument(
+    "--coco_image_dir_path",
+    type=str,
+    help="Path to the flickr30/flickr30k_images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--coco_annotations_json_path",
+    type=str,
+    default=None,
+)
+
+## VQAV2 Dataset
+parser.add_argument(
+    "--vqav2_image_dir_path",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--vqav2_questions_json_path",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--vqav2_annotations_json_path",
+    type=str,
+    default=None,
+)
+
+## OK-VQA Dataset
+parser.add_argument(
+    "--ok_vqa_image_dir_path",
+    type=str,
+    help="Path to the vqav2/train2014 directory.",
+    default=None,
+)
+parser.add_argument(
+    "--ok_vqa_questions_json_path",
+    type=str,
+    help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
+    default=None,
+)
+parser.add_argument(
+    "--ok_vqa_annotations_json_path",
+    type=str,
+    help="Path to the v2_mscoco_train2014_annotations.json file.",
+    default=None,
+)
+
+## Imagenet dataset
+parser.add_argument("--imagenet_root", type=str, default="/tmp")
+
+## RefCOCO dataset
+parser.add_argument("--refcoco_tsvfile", type=str, default=None)
+
+parser.add_argument(
+    "--location_token_num",
+    default=1000,
+    type=int,
+)
+# distributed training
+parser.add_argument(
+    "--dist-url",
+    default="env://",
+    type=str,
+    help="url used to set up distributed training",
+)
+parser.add_argument(
+    "--dist-backend", default="nccl", type=str, help="distributed backend"
+)
+parser.add_argument(
+    "--horovod",
+    default=False,
+    action="store_true",
+    help="Use horovod for distributed training.",
+)
+parser.add_argument(
+    "--no-set-device-rank",
+    default=False,
+    action="store_true",
+    help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+)
+parser.add_argument(
+    "--dist",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--lora",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--lora_r",
+    default=16,
+    type=int,
+    required=False,
+)
+parser.add_argument(
+    "--legacy",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--special",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--id",
+    default=0,
+    type=int,
+    required=False,
+)
+
+parser.add_argument(
+    "--eval_gqa",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--use_sam",
+    default=None,
+    type=str,
+    required=False,
+)
+parser.add_argument(
+    "--add_visual_token",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--use_format_v2",
+    default=False,
+    action="store_true",
+)
+
+
+class OKVQAPostProcess():
+    def __init__(self):
+        self._lemmatizer = None
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+        
+
+def main():
+    args = parser.parse_args()
+    if args.dist:
+        args.local_rank, args.rank, args.world_size = world_info_from_env()
+        print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
+        device_id = init_distributed_device(args)
+    else:
+        args.rank = 0
+        args.world_size = 1
+        print(f"rank: {args.rank} world_size: {args.world_size}")
+    
+    if "sam" in args.checkpoint_path:
+        args.use_sam = "vit_l"
+
+    args.add_visual_token = True
+    if "lora" in args.checkpoint_path:
+        args.lora = True
+
+
+    args.add_pe = False
+    args.add_box = False
+    args.relation = False
+    if "debug" in args.checkpoint_path:
+        # args.add_pe = True
+        args.add_box = True
+    if "box" in args.checkpoint_path:
+        args.add_box = True
+    if "pe" in args.checkpoint_path:
+        args.add_pe = True
+    if "rel" in args.checkpoint_path:
+        args.relation = True
+        args.add_pe = False
+    if "previsual" in args.checkpoint_path:
+        args.use_format_v2 = True
+        args.relation = False
+
+
+
+    # load model
+    flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
+        args.vision_encoder_path,
+        args.vision_encoder_pretrained,
+        args.lm_path,
+        args.lm_tokenizer_path,
+        location_token_num=args.location_token_num,
+        lora=args.lora,
+        lora_r=16,
+        use_sam=args.use_sam,
+        add_visual_token=args.add_visual_token,
+        use_format_v2=args.use_format_v2,
+        add_box=args.add_box,
+        add_pe=args.add_pe,
+        add_relation=args.relation,
+    )
+    flamingo.use_format_v2 = args.use_format_v2
+    if args.special:
+        flamingo.special = True
+    else:
+        flamingo.special = False
+    if args.legacy:
+        flamingo.legacy = True
+        print("use legacy evaluation")
+    flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
+    flamingo.expr_name = args.checkpoint_path.split("/")[-2]
+    if args.rank == 0:
+        print("legacy", True if hasattr(flamingo, "legacy") else False)
+        print("step:", flamingo.step_num)
+        print("expr:", flamingo.expr_name)
+        print("use format v2:", flamingo.use_format_v2)
+        print(args)
+    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+    model_state_dict = {}
+    for key in checkpoint["model_state_dict"].keys():
+        model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
+    if "vision_encoder.logit_scale"in model_state_dict:
+        # previous checkpoint has some unnecessary weights
+        del model_state_dict["vision_encoder.logit_scale"]
+        del model_state_dict["vision_encoder.visual.proj"]
+        del model_state_dict["vision_encoder.visual.ln_post.weight"]
+        del model_state_dict["vision_encoder.visual.ln_post.bias"]
+    flamingo.load_state_dict(model_state_dict, strict=True)
+    results = defaultdict(list)
+    if args.eval_coco:
+        print("Evaluating on COCO...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                cider_score = evaluate_coco_flickr(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.coco_image_dir_path,
+                    annotations_json_path=args.coco_annotations_json_path,
+                    device=args.device,
+                    seed=seed,
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+                print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
+                scores.append(cider_score)
+            print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
+            results["coco"].append(
+                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
+            )
+
+    if args.eval_ok_vqa:
+        print("Evaluating on OK-VQA...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                ok_vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.ok_vqa_image_dir_path,
+                    questions_json_path=args.ok_vqa_questions_json_path,
+                    annotations_json_path=args.ok_vqa_annotations_json_path,
+                    vqa_dataset="ok_vqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["ok_vqa"].append(
+                {"shots": shot, "score": ok_vqa_score}
+            )
+
+    if args.eval_vqav2:
+        print("Evaluating on VQAv2...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.vqav2_image_dir_path,
+                    questions_json_path=args.vqav2_questions_json_path,
+                    annotations_json_path=args.vqav2_annotations_json_path,
+                    vqa_dataset="vqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["vqav2"].append(
+                {"shots": shot, "score": vqa_score}
+            )
+
+    if args.eval_gqa:
+        print("Evaluating on GQA...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    vqa_dataset="gqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["gqa"].append(
+                {"shots": shot, "score": vqa_score}
+            )
+
+    if args.eval_imagenet:
+        print("Evaluating on ImageNet...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                imagenet_score = evaluate_imagenet(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    num_samples=args.num_samples,
+                    num_shots=shot,
+                    device=args.device,
+                    seed=seed,
+                    imagenet_root=args.imagenet_root,
+                )
+                print(
+                    f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
+                )
+                scores.append(imagenet_score)
+            print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
+            results["imagenet"].append(
+                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
+            )
+
+    if args.eval_refcoco:
+        print("Evaluating on RefCOCO...")
+        refcoco_score = evaluate_refcoco(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["refcoco"].append(
+            {"score": refcoco_score}
+        )
+
+def prepare_batch_images(batch, image_processor):
+    batch_images = None
+    for b in batch:
+        b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        if batch_images is None:
+            batch_images = b_image
+        else:
+            batch_images = torch.cat([batch_images, b_image], dim=0)
+    return batch_images
+
+def get_outputs(
+    model,
+    batch_images,
+    attention_mask,
+    max_generation_length,
+    min_generation_length,
+    num_beams,
+    length_penalty,
+    input_ids,
+    image_start_index_list=None,
+    image_nums=None,
+    bad_words_ids=None,
+):
+    with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+        outputs = model.generate(
+            batch_images,
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_generation_length,
+            min_length=min_generation_length,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+            bad_words_ids=bad_words_ids,
+        )
+
+    outputs = outputs[:, len(input_ids[0]) :]
+    return outputs
+
+
+def evaluate_coco_flickr(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    image_dir_path,
+    annotations_json_path,
+    seed=42,
+    max_generation_length=20,
+    num_beams=1,
+    length_penalty=-2.0,
+    device=-1,
+    is_flickr=False,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    """Evaluate a model on COCO dataset.
+
+    Args:
+        model (nn.Module): model to evaluate
+        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
+        image_processor : image processor for the model
+        batch_size (int): batch size
+        image_dir_path (str, optional): path to the directory containing the images.
+        annotations_json_path (str, optional): path to the json file containing the annotations.
+        seed (int, optional): seed for random number generator. Defaults to 42.
+        max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
+        query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
+        num_shots (int, optional): number of in-context samples to use. Defaults to 8.
+        device (int, optional): device to use. Defaults to -1.
+        num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
+        is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
+
+    Returns:
+        float: CIDEr score
+
+    """
+    # eval_dataset = COCOFlickrDataset(
+    #     image_dir_path=image_dir_path,
+    #     annotations_path=annotations_json_path,
+    #     is_flickr=is_flickr,
+    # )
+    coco_dataset = load_dataset("coco_caption")
+    eval_dataset = coco_dataset["test"]
+
+
+    model.eval().cuda()
+    predictions = defaultdict()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    # if "peft" in lang_encoder_name:
+        # lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
+    try:
+        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    except:
+        pass
+
+    def get_prompt(sample):
+        return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+
+    tokenizer.padding_side = "left"
+    cnt = 0
+    if world_size > 1:
+        torch.distributed.barrier()
+    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
+    )):
+        if ii % world_size != rank:
+            continue
+        cnt += len(batch)
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        batch_text = [get_prompt(s) for s in batch]
+        encodings = tokenizer(
+            batch_text,
+            padding="longest",
+            truncation=True,
+            return_tensors="pt",
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"].cuda()
+        attention_mask = encodings["attention_mask"].cuda()
+        skip_special_tokens = False
+        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
+            if rank == 0:
+                tqdm.write("use legacy model")
+            skip_special_tokens = True
+            for i in range(len(input_ids)):
+                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
+                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
+                input_ids[i, media_token_index - 1] = media_token_id
+                input_ids[i, media_token_index] = pad_token_id
+                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
+                input_ids[i, endofmedia_token_index] = bos_token_id
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        if "llama" in lang_encoder_name:
+            attention_mask[input_ids == 0] = 0
+        outputs = get_outputs(
+            model=model,
+            batch_images=batch_images,
+            attention_mask=attention_mask,
+            max_generation_length=30,
+            min_generation_length=8,
+            num_beams=5,
+            length_penalty=0,
+            input_ids=input_ids,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+        )
+        new_predictions = [
+            postprocess_captioning_generation(out).replace('"', "")
+            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ]
+        # if rank == 0:
+        #     tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
+
+        for i, sample in enumerate(batch):
+            predictions[int(sample["image_id"])] = {
+                "caption": new_predictions[i],
+            }
+    results_path = (
+        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
+        if is_flickr
+        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
+    )
+    with open(results_path, "w") as f:
+        f.write(
+            json.dumps(
+                [
+                    {"image_id": k, "caption": predictions[k]["caption"]}
+                    for k in predictions
+                ],
+                indent=2,
+            )
+        )
+    print("save to", results_path)
+    del predictions
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            part_results_path = (
+                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
+                if is_flickr
+                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
+            )
+            print("load", part_results_path)
+            predictions.extend(json.load(open(part_results_path)))
+            os.remove(part_results_path)
+        print("num:", len(predictions))
+        results_path = (
+            f"flickrresults_{lang_encoder_name}.json"
+            if is_flickr
+            else f"cocoresults_{lang_encoder_name}.json"
+        )
+        json.dump(predictions, open(results_path, "w"), indent=2)
+
+        metrics = compute_cider(
+            result_path=results_path,
+            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
+        )
+        os.makedirs("eval_results", exist_ok=True)
+        acc = metrics["CIDEr"]
+        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(results_path)
+    else:
+        metrics = {}
+        metrics["CIDEr"] = 0.0
+
+    return metrics["CIDEr"]
+
+
+def evaluate_vqa(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    image_dir_path=None,
+    questions_json_path=None,
+    annotations_json_path=None,
+    vqa_dataset="vqa",
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    """
+    Evaluate a model on VQA datasets. Currently supports VQA v2.0.
+
+    Args:
+        model (nn.Module): model to evaluate
+        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
+        image_processor : image processor for the model
+        batch_size (int): batch size
+        image_dir_path (str): path to image directory
+        questions_json_path (str): path to questions json file
+        annotations_json_path (str): path to annotations json file
+        seed (int, optional): random seed. Defaults to 42.
+        max_generation_length (int, optional): max generation length. Defaults to 5.
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
+        query_set_size (int, optional): size of the query set. Defaults to 2048.
+        num_shots (int, optional): number of shots to use. Defaults to 8.
+        device (int, optional): device to use. Defaults to -1 (cpu).
+        num_workers (int, optional): number of workers to use. Defaults to 4.
+        vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
+    Returns:
+        float: accuracy score
+    """
+    if world_size > 1:
+        torch.distributed.barrier()
+    if vqa_dataset == "gqa":
+        eval_dataset = GQADataset()
+    else:
+        eval_dataset = VQADataset(
+            image_dir_path=image_dir_path,
+            question_path=questions_json_path,
+            annotations_path=annotations_json_path,
+            vqa_dataset=vqa_dataset,
+        )
+    postprocessor = OKVQAPostProcess()
+    try:
+        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    except:
+        pass
+    def get_prompt(sample):
+        return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
+        # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+
+    model.eval().cuda()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    if "peft" in lang_encoder_name:
+        lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
+    predictions = []
+    tokenizer.padding_side = "left"
+    if world_size > 1:
+        torch.distributed.barrier()
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
+    )):
+        if ii % world_size != rank:
+            continue
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        batch_text = [get_prompt(s) for s in batch]
+        encodings = tokenizer(
+            batch_text,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"].cuda()
+        attention_mask = encodings["attention_mask"].cuda()
+        skip_special_tokens = True
+        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
+            if rank == 0:
+                tqdm.write("use legacy model")
+            for i in range(len(input_ids)):
+                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
+                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
+                input_ids[i, media_token_index - 1] = media_token_id
+                input_ids[i, media_token_index] = pad_token_id
+                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
+                input_ids[i, endofmedia_token_index] = bos_token_id
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        if "llama" in lang_encoder_name:
+            attention_mask[input_ids == 0] = 0
+        outputs = get_outputs(
+            model=model,
+            batch_images=batch_images,
+            attention_mask=attention_mask,
+            max_generation_length=10,
+            min_generation_length=1,
+            num_beams=5,
+            length_penalty=0,
+            input_ids=input_ids,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+        )
+        # postprocess begin
+        new_predictions = [
+            out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
+        ]
+        if vqa_dataset == "ok_vqa":
+            new_predictions = postprocessor._lemmatize(new_predictions)
+        if model.special:
+            for i in range(len(new_predictions)):
+                for answer, _ in Counter(batch[i]['answers']).most_common():
+                    if answer in new_predictions[i]:
+                        new_predictions[i] = answer
+                        break
+                    if "cant" in new_predictions[i] and "no" == answer:
+                        new_predictions[i] = answer
+                        break
+                    if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
+                        new_predictions[i] = answer
+                        break
+
+        # if rank == 0:
+        #     tqdm.write(f"{image_nums} {image_start_index_list}")
+        #     for i in range(1):
+        #         tqdm.write(f"ID: {batch[i]['question_id']} | gt QA: {batch[i]['question']} {Counter(batch[i]['answers']).most_common()}")
+        #         tqdm.write("prompt: " + tokenizer.decode(input_ids[i]))
+        #         tqdm.write("model output: " + new_predictions[i])
+
+        predictions.extend(
+            [
+                {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
+                for p, sample in zip(new_predictions, batch)
+            ]
+        )
+    with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps(predictions))
+    print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
+
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+            predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
+            os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+        print("num:", len(predictions))
+        # save the predictions to a temporary file
+        random_uuid = str(uuid.uuid4())
+        with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
+            f.write(json.dumps(predictions, indent=4))
+
+        if vqa_dataset == "gqa":
+            acc = compute_gqa_accuracy(predictions)
+        else:
+            acc = compute_vqa_accuracy(
+                f"{vqa_dataset}results_{random_uuid}.json",
+                questions_json_path,
+                annotations_json_path,
+                vqa_dataset=vqa_dataset,
+            )
+        print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
+        os.makedirs("eval_results", exist_ok=True)
+        with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(f"{vqa_dataset}results_{random_uuid}.json")
+    else:
+        time.sleep(5)
+        acc = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return acc
+
+
+def evaluate_refcoco(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    model.eval().cuda()
+    loc_token_ids = []
+    for i in range(1000):
+        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    total = 0
+    correct = 0
+    ious = []
+    if "refcocog" in tsvfile:
+        dataset_name = "refcocog"
+    elif "refcocoplus" in tsvfile:
+        dataset_name = "refcocoplus"
+    else:
+        dataset_name = "refcoco"
+    with open(tsvfile, "r") as f:
+        lines = f.readlines()
+        pbar = tqdm(lines, disable=(rank != 0))
+        for ii, line in enumerate(pbar):
+            if ii % world_size != rank:
+                continue
+            total += 1
+            line = line.rstrip()
+            uniq_id, image_id, text, region_coord, image = line.split("\t")
+
+            # image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
+            # image2 = Image.open("yolo.png").convert("RGB")
+            # image1 = image1.resize((224, 224))
+            # image2 = image2.resize((224, 224))
+            # images = [image1, image2]
+
+            # gt_box = np.array(list(map(float, region_coord.split(","))))
+            # width = image.width
+            # height = image.height
+            # gt_box /= np.array([width, height, width, height])
+            # batch_images = [image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0) for image in images]
+            # batch_images = torch.cat(batch_images, dim=0)
+            # image = Image.open("yolo_test.png").convert("RGB")
+            image = Image.open("example.png").convert("RGB")
+            image = image.resize((224, 224))
+            batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text.rstrip('.')}<|#visual#|>"]
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#endofattr#|>man<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|> is sitting on<|#object#|><|#previsual#|>"]
+            # prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|>man<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|> is sitting on<|#object#|><|#previsual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
+
+
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [image_start_index_list]
+            image_nums = [1]
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+            print(image_start_index_list, image_nums)
+
+            model.debug_id = 0
+            # outputs = get_outputs(
+            #     model=model,
+            #     batch_images=vision_x,
+            #     attention_mask=attention_mask,
+            #     max_generation_length=20,
+            #     min_generation_length=8,
+            #     num_beams=5,
+            #     length_penalty=0,
+            #     input_ids=lang_x,
+            #     image_start_index_list=image_start_index_list,
+            #     image_nums=image_nums,
+            # )
+            # print(tokenizer.decode(outputs[0]))
+            # exit()
+
+            prebox = [93, 20, 155, 172] # man
+            # prebox = [32, 82, 89, 213] # dog
+            # prebox = [34, 49, 166, 164] # bike
+            with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=None,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=[torch.tensor(prebox).cuda().unsqueeze(0) / 224],
+                    add_box=True,
+                    debug_mode=True,
+                )
+            
+            boxes = outputs["boxes"]
+            scores = outputs["scores"]
+            box = boxes[scores.argmax()]
+            open_cv_image = np.array(image)
+            # Convert RGB to BGR 
+            open_cv_image = open_cv_image[:, :, ::-1].copy() 
+            open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+            open_cv_image = cv2.rectangle(open_cv_image, prebox[:2], prebox[2:], (0, 0, 255), 2)
+            cv2.imwrite(f"output2.jpg", open_cv_image)
+            print(box)
+            print(prebox)
+            exit()
+
+            # force_words = ["man", "table"]
+            # force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids
+
+
+            # sequences, hidden_states_for_each_step = get_outputs(
+            #     model=model,
+            #     batch_images=vision_x,
+            #     attention_mask=attention_mask,
+            #     max_generation_length=20,
+            #     min_generation_length=8,
+            #     num_beams=5,
+            #     length_penalty=0,
+            #     input_ids=lang_x,
+            #     image_start_index_list=image_start_index_list,
+            #     image_nums=image_nums,
+            #     force_words_ids=force_words_ids,
+            # )
+            # sequence = sequences[0]
+            # print(tokenizer.decode(sequence))
+            # for i, token in enumerate(sequence):
+            #     if token == model.visual_token_id:
+            #         print(tokenizer.decode(sequence[:i+1]))
+            #         if hasattr(model, "debug_id"):
+            #             model.debug_id += 1
+            #         else:
+            #             model.debug_id = 0
+            #         this_lang_x = torch.hstack([lang_x[0], sequence[:i+1]]).unsqueeze(0)
+            #         this_attention_mask = torch.ones_like(this_lang_x).cuda()
+            #         with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+            #             _ = model(
+            #                 vision_x=vision_x,
+            #                 lang_x=this_lang_x,
+            #                 attention_mask=this_attention_mask,
+            #                 labels=None,
+            #                 image_nums=image_nums,
+            #                 image_start_index_list=image_start_index_list,
+            #                 added_bbox_list=None,
+            #             )
+            # exit()
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/open_flamingo/eval/evaluate_find_showcase.py b/multimodal/open_flamingo/eval/evaluate_find_showcase.py
new file mode 100644
index 0000000000000000000000000000000000000000..a16d1f0c9a0d301f3577d63e4c83662a373a6bed
--- /dev/null
+++ b/multimodal/open_flamingo/eval/evaluate_find_showcase.py
@@ -0,0 +1,1700 @@
+import argparse
+import json
+from math import ceil
+import os
+import random
+import uuid
+from collections import defaultdict
+from typing import Callable
+import time
+import cv2
+import webdataset as wds
+from sklearn.metrics import recall_score, average_precision_score
+
+import more_itertools
+import numpy as np
+import torch
+from coco_metric import compute_cider, postprocess_captioning_generation
+from eval_datasets import VQADataset
+from tqdm import tqdm
+from collections import Counter
+
+from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
+from open_flamingo.eval.classification import (
+    compute_per_sample_probs,
+    compute_per_sample_loss,
+)
+from open_flamingo.eval.imagenet_utils import (
+    openai_imagenet_classnames,
+    IMAGENET_1K_CLASS_ID_TO_LABEL,
+)
+
+from open_flamingo.src.factory import create_model_and_transforms
+from PIL import Image
+from io import BytesIO
+import base64
+from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
+import string
+from lavis.datasets.builders import load_dataset
+from open_flamingo.eval.task.reg import evaluate_reg
+from open_flamingo.eval.task.gqa import GQADataset
+from open_flamingo.eval.task.vl_checklist import evaluate_vlc
+from open_flamingo.eval.task.crepe import evaluate_crepe
+
+def get_iou(box1, box2):
+    # box1 and box2 should be in the format [x1, y1, x2, y2]
+    intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
+                   max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
+    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    union = area_box1 + area_box2 - intersection
+    iou = intersection / union if union > 0 else 0
+    return iou
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
+parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
+parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
+parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
+parser.add_argument("--checkpoint_path", type=str, required=True)
+parser.add_argument(
+    "--results_file", type=str, default=None, help="JSON file to save results"
+)
+
+# Trial arguments
+parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
+parser.add_argument(
+    "--num_trials",
+    type=int,
+    default=1,
+    help="Number of trials to run for each shot using different demonstrations",
+)
+parser.add_argument(
+    "--trial_seeds",
+    nargs="+",
+    default=[0],
+    help="Seeds to use for each trial for picking demonstrations and eval sets",
+)
+parser.add_argument(
+    "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
+)
+
+parser.add_argument("--batch_size", type=int, default=8)
+
+# Per-dataset evaluation flags
+parser.add_argument(
+    "--eval_coco",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on COCO.",
+)
+parser.add_argument(
+    "--eval_vqav2",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on VQAV2.",
+)
+parser.add_argument(
+    "--eval_ok_vqa",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on OK-VQA.",
+)
+parser.add_argument(
+    "--eval_imagenet",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on ImageNet.",
+)
+
+parser.add_argument(
+    "--eval_flickr30",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on Flickr30.",
+)
+
+parser.add_argument(
+    "--eval_refcoco",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on RefCOCO.",
+)
+
+# Dataset arguments
+
+## Flickr30 Dataset
+parser.add_argument(
+    "--flickr_image_dir_path",
+    type=str,
+    help="Path to the flickr30/flickr30k_images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--flickr_annotations_json_path",
+    type=str,
+    help="Path to the dataset_flickr30k_coco_style.json file.",
+    default=None,
+)
+
+## COCO Dataset
+parser.add_argument(
+    "--coco_image_dir_path",
+    type=str,
+    help="Path to the flickr30/flickr30k_images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--coco_annotations_json_path",
+    type=str,
+    default=None,
+)
+
+## VQAV2 Dataset
+parser.add_argument(
+    "--vqav2_image_dir_path",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--vqav2_questions_json_path",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--vqav2_annotations_json_path",
+    type=str,
+    default=None,
+)
+
+## OK-VQA Dataset
+parser.add_argument(
+    "--ok_vqa_image_dir_path",
+    type=str,
+    help="Path to the vqav2/train2014 directory.",
+    default=None,
+)
+parser.add_argument(
+    "--ok_vqa_questions_json_path",
+    type=str,
+    help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
+    default=None,
+)
+parser.add_argument(
+    "--ok_vqa_annotations_json_path",
+    type=str,
+    help="Path to the v2_mscoco_train2014_annotations.json file.",
+    default=None,
+)
+
+## Imagenet dataset
+parser.add_argument("--imagenet_root", type=str, default="/tmp")
+
+## RefCOCO dataset
+parser.add_argument("--refcoco_tsvfile", type=str, default=None)
+
+parser.add_argument(
+    "--location_token_num",
+    default=1000,
+    type=int,
+)
+# distributed training
+parser.add_argument(
+    "--dist-url",
+    default="env://",
+    type=str,
+    help="url used to set up distributed training",
+)
+parser.add_argument(
+    "--dist-backend", default="nccl", type=str, help="distributed backend"
+)
+parser.add_argument(
+    "--horovod",
+    default=False,
+    action="store_true",
+    help="Use horovod for distributed training.",
+)
+parser.add_argument(
+    "--no-set-device-rank",
+    default=False,
+    action="store_true",
+    help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+)
+parser.add_argument(
+    "--dist",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--lora",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--lora_r",
+    default=16,
+    type=int,
+    required=False,
+)
+parser.add_argument(
+    "--legacy",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--special",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--id",
+    default=0,
+    type=int,
+    required=False,
+)
+
+parser.add_argument(
+    "--eval_gqa",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--use_sam",
+    default=None,
+    type=str,
+    required=False,
+)
+parser.add_argument(
+    "--add_visual_token",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--use_format_v2",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_aro",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_pisc",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_reg",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_vlc",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_crepe",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--level",
+    default=4,
+    type=int,
+)
+parser.add_argument(
+    "--type",
+    default="swap",
+    type=str,
+)
+
+
+class OKVQAPostProcess():
+    def __init__(self):
+        self._lemmatizer = None
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+        
+
+def main():
+    args = parser.parse_args()
+    if args.dist:
+        args.local_rank, args.rank, args.world_size = world_info_from_env()
+        print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
+        device_id = init_distributed_device(args)
+    else:
+        args.rank = 0
+        args.world_size = 1
+        print(f"rank: {args.rank} world_size: {args.world_size}")
+    
+    if "sam" in args.checkpoint_path:
+        args.use_sam = "vit_l"
+
+    args.add_visual_token = True
+    if "lora" in args.checkpoint_path:
+        args.lora = True
+
+
+    args.add_pe = False
+    args.add_box = True
+    args.relation = False
+    args.enhance_data = False
+    args.use_format_v2 = True
+
+
+
+    import hashlib
+    args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
+
+    # load model
+    flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
+        args.vision_encoder_path,
+        args.vision_encoder_pretrained,
+        args.lm_path,
+        args.lm_tokenizer_path,
+        location_token_num=args.location_token_num,
+        lora=args.lora,
+        lora_r=16,
+        use_sam=args.use_sam,
+        add_visual_token=args.add_visual_token,
+        use_format_v2=args.use_format_v2,
+        add_box=args.add_box,
+        add_pe=args.add_pe,
+        add_relation=args.relation,
+        enhance_data=args.enhance_data,
+    )
+    flamingo.use_format_v2 = args.use_format_v2
+    if args.special:
+        flamingo.special = True
+    else:
+        flamingo.special = False
+    if args.legacy:
+        flamingo.legacy = True
+        print("use legacy evaluation")
+    flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
+    flamingo.expr_name = args.checkpoint_path.split("/")[-2]
+    if args.rank == 0:
+        print("legacy", True if hasattr(flamingo, "legacy") else False)
+        print("step:", flamingo.step_num)
+        print("expr:", flamingo.expr_name)
+        print("use format v2:", flamingo.use_format_v2)
+        print(args)
+    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+    model_state_dict = {}
+    for key in checkpoint["model_state_dict"].keys():
+        model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
+    if "vision_encoder.logit_scale"in model_state_dict:
+        # previous checkpoint has some unnecessary weights
+        del model_state_dict["vision_encoder.logit_scale"]
+        del model_state_dict["vision_encoder.visual.proj"]
+        del model_state_dict["vision_encoder.visual.ln_post.weight"]
+        del model_state_dict["vision_encoder.visual.ln_post.bias"]
+    flamingo.load_state_dict(model_state_dict, strict=True)
+    results = defaultdict(list)
+    if args.eval_coco:
+        print("Evaluating on COCO...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                cider_score = evaluate_coco_flickr(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.coco_image_dir_path,
+                    annotations_json_path=args.coco_annotations_json_path,
+                    device=args.device,
+                    seed=seed,
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+                print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
+                scores.append(cider_score)
+            print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
+            results["coco"].append(
+                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
+            )
+
+    if args.eval_ok_vqa:
+        print("Evaluating on OK-VQA...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                ok_vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.ok_vqa_image_dir_path,
+                    questions_json_path=args.ok_vqa_questions_json_path,
+                    annotations_json_path=args.ok_vqa_annotations_json_path,
+                    vqa_dataset="ok_vqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["ok_vqa"].append(
+                {"shots": shot, "score": ok_vqa_score}
+            )
+
+    if args.eval_vqav2:
+        print("Evaluating on VQAv2...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.vqav2_image_dir_path,
+                    questions_json_path=args.vqav2_questions_json_path,
+                    annotations_json_path=args.vqav2_annotations_json_path,
+                    vqa_dataset="vqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["vqav2"].append(
+                {"shots": shot, "score": vqa_score}
+            )
+
+    if args.eval_gqa:
+        print("Evaluating on GQA...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    vqa_dataset="gqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["gqa"].append(
+                {"shots": shot, "score": vqa_score}
+            )
+
+    if args.eval_refcoco:
+        print("Evaluating on RefCOCO...")
+        refcoco_score = evaluate_refcoco(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["refcoco"].append(
+            {"score": refcoco_score}
+        )
+    if args.eval_aro:
+        print("Evaluating on ARO...")
+        aro_score = evaluate_aro(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+            add_relation=args.relation,
+        )
+        results["aro"].append(
+            {"score": aro_score}
+        )
+    if args.eval_pisc:
+        print("Evaluating on ARO...")
+        aro_score = evaluate_pisc(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["pisc"].append(
+            {"score": aro_score}
+        )
+    if args.eval_reg:
+        print("Evaluating on Referring Expression Generation...")
+        cider = evaluate_reg(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["reg"].append(
+            {"score": cider}
+        )
+    if args.eval_vlc:
+        print("Evaluating on VL-checklist...")
+        vlc_score = evaluate_vlc(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["vlc"].append(
+            {"score": vlc_score}
+        )
+    if args.eval_crepe:
+        print("Evaluating on CREPE...")
+        crepe_score = evaluate_crepe(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+            level=args.level,
+            type=args.type,
+        )
+        results["crepe"].append(
+            {"score": crepe_score}
+        )
+
+def prepare_batch_images(batch, image_processor):
+    batch_images = None
+    for b in batch:
+        b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        if batch_images is None:
+            batch_images = b_image
+        else:
+            batch_images = torch.cat([batch_images, b_image], dim=0)
+    return batch_images
+
+def get_outputs(
+    model,
+    batch_images,
+    attention_mask,
+    max_generation_length,
+    min_generation_length,
+    num_beams,
+    length_penalty,
+    input_ids,
+    image_start_index_list=None,
+    image_nums=None,
+    bad_words_ids=None,
+):
+    with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+        outputs = model.generate(
+            batch_images,
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_generation_length,
+            min_length=min_generation_length,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+            bad_words_ids=bad_words_ids,
+        )
+
+    outputs = outputs[:, len(input_ids[0]) :]
+    return outputs
+
+
+def evaluate_coco_flickr(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    image_dir_path,
+    annotations_json_path,
+    seed=42,
+    max_generation_length=20,
+    num_beams=1,
+    length_penalty=-2.0,
+    device=-1,
+    is_flickr=False,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    """Evaluate a model on COCO dataset.
+
+    Args:
+        model (nn.Module): model to evaluate
+        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
+        image_processor : image processor for the model
+        batch_size (int): batch size
+        image_dir_path (str, optional): path to the directory containing the images.
+        annotations_json_path (str, optional): path to the json file containing the annotations.
+        seed (int, optional): seed for random number generator. Defaults to 42.
+        max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
+        query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
+        num_shots (int, optional): number of in-context samples to use. Defaults to 8.
+        device (int, optional): device to use. Defaults to -1.
+        num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
+        is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
+
+    Returns:
+        float: CIDEr score
+
+    """
+    # eval_dataset = COCOFlickrDataset(
+    #     image_dir_path=image_dir_path,
+    #     annotations_path=annotations_json_path,
+    #     is_flickr=is_flickr,
+    # )
+    coco_dataset = load_dataset("coco_caption")
+    eval_dataset = coco_dataset["test"]
+
+
+    model.eval().cuda()
+    predictions = defaultdict()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    # if "peft" in lang_encoder_name:
+        # lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
+    try:
+        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    except:
+        pass
+
+    def get_prompt(sample):
+        return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+
+    tokenizer.padding_side = "left"
+    cnt = 0
+    if world_size > 1:
+        torch.distributed.barrier()
+    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
+    )):
+        if ii % world_size != rank:
+            continue
+        cnt += len(batch)
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        batch_text = [get_prompt(s) for s in batch]
+        encodings = tokenizer(
+            batch_text,
+            padding="longest",
+            truncation=True,
+            return_tensors="pt",
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"].cuda()
+        attention_mask = encodings["attention_mask"].cuda()
+        skip_special_tokens = False
+        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
+            if rank == 0:
+                tqdm.write("use legacy model")
+            skip_special_tokens = True
+            for i in range(len(input_ids)):
+                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
+                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
+                input_ids[i, media_token_index - 1] = media_token_id
+                input_ids[i, media_token_index] = pad_token_id
+                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
+                input_ids[i, endofmedia_token_index] = bos_token_id
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        if "llama" in lang_encoder_name:
+            attention_mask[input_ids == 0] = 0
+        outputs = get_outputs(
+            model=model,
+            batch_images=batch_images,
+            attention_mask=attention_mask,
+            max_generation_length=30,
+            min_generation_length=8,
+            num_beams=5,
+            length_penalty=0,
+            input_ids=input_ids,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+        )
+        new_predictions = [
+            postprocess_captioning_generation(out).replace('"', "")
+            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ]
+        # if rank == 0:
+        #     tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
+
+        for i, sample in enumerate(batch):
+            predictions[int(sample["image_id"])] = {
+                "caption": new_predictions[i],
+            }
+    results_path = (
+        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
+        if is_flickr
+        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
+    )
+    with open(results_path, "w") as f:
+        f.write(
+            json.dumps(
+                [
+                    {"image_id": k, "caption": predictions[k]["caption"]}
+                    for k in predictions
+                ],
+                indent=2,
+            )
+        )
+    print("save to", results_path)
+    del predictions
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            part_results_path = (
+                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
+                if is_flickr
+                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
+            )
+            print("load", part_results_path)
+            predictions.extend(json.load(open(part_results_path)))
+            os.remove(part_results_path)
+        print("num:", len(predictions))
+        results_path = (
+            f"flickrresults_{lang_encoder_name}.json"
+            if is_flickr
+            else f"cocoresults_{lang_encoder_name}.json"
+        )
+        json.dump(predictions, open(results_path, "w"), indent=2)
+
+        metrics = compute_cider(
+            result_path=results_path,
+            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
+        )
+        os.makedirs("eval_results", exist_ok=True)
+        acc = metrics["CIDEr"]
+        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(results_path)
+    else:
+        metrics = {}
+        metrics["CIDEr"] = 0.0
+
+    return metrics["CIDEr"]
+
+
+def evaluate_vqa(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    image_dir_path=None,
+    questions_json_path=None,
+    annotations_json_path=None,
+    vqa_dataset="vqa",
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    """
+    Evaluate a model on VQA datasets. Currently supports VQA v2.0.
+
+    Args:
+        model (nn.Module): model to evaluate
+        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
+        image_processor : image processor for the model
+        batch_size (int): batch size
+        image_dir_path (str): path to image directory
+        questions_json_path (str): path to questions json file
+        annotations_json_path (str): path to annotations json file
+        seed (int, optional): random seed. Defaults to 42.
+        max_generation_length (int, optional): max generation length. Defaults to 5.
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
+        query_set_size (int, optional): size of the query set. Defaults to 2048.
+        num_shots (int, optional): number of shots to use. Defaults to 8.
+        device (int, optional): device to use. Defaults to -1 (cpu).
+        num_workers (int, optional): number of workers to use. Defaults to 4.
+        vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
+    Returns:
+        float: accuracy score
+    """
+    if world_size > 1:
+        torch.distributed.barrier()
+    if vqa_dataset == "gqa":
+        eval_dataset = GQADataset()
+    else:
+        eval_dataset = VQADataset(
+            image_dir_path=image_dir_path,
+            question_path=questions_json_path,
+            annotations_path=annotations_json_path,
+            vqa_dataset=vqa_dataset,
+        )
+    postprocessor = OKVQAPostProcess()
+    try:
+        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    except:
+        pass
+    def get_prompt(sample):
+        return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
+        # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+
+    model.eval().cuda()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    if "peft" in lang_encoder_name:
+        lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
+    predictions = []
+    tokenizer.padding_side = "left"
+    if world_size > 1:
+        torch.distributed.barrier()
+    this_tot = 0
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
+    )):
+        if ii % world_size != rank:
+            continue
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        batch_text = [get_prompt(s) for s in batch]
+        encodings = tokenizer(
+            batch_text,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"].cuda()
+        attention_mask = encodings["attention_mask"].cuda()
+        skip_special_tokens = True
+        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
+            if rank == 0:
+                tqdm.write("use legacy model")
+            for i in range(len(input_ids)):
+                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
+                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
+                input_ids[i, media_token_index - 1] = media_token_id
+                input_ids[i, media_token_index] = pad_token_id
+                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
+                input_ids[i, endofmedia_token_index] = bos_token_id
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        if "llama" in lang_encoder_name:
+            attention_mask[input_ids == 0] = 0
+        outputs = get_outputs(
+            model=model,
+            batch_images=batch_images,
+            attention_mask=attention_mask,
+            max_generation_length=10,
+            min_generation_length=1,
+            num_beams=5,
+            length_penalty=0,
+            input_ids=input_ids,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+        )
+        # postprocess begin
+        new_predictions = [
+            out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
+        ]
+        if vqa_dataset == "ok_vqa":
+            new_predictions = postprocessor._lemmatize(new_predictions)
+        if model.special:
+            for i in range(len(new_predictions)):
+                for answer, _ in Counter(batch[i]['answers']).most_common():
+                    if answer in new_predictions[i]:
+                        new_predictions[i] = answer
+                        break
+                    if "cant" in new_predictions[i] and "no" == answer:
+                        new_predictions[i] = answer
+                        break
+                    if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
+                        new_predictions[i] = answer
+                        break
+        
+        this_tot += 1
+        if rank == 0 and this_tot % 20 == 0:
+            for i in range(1):
+                tqdm.write("model output: " + new_predictions[i])
+
+        predictions.extend(
+            [
+                {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
+                for p, sample in zip(new_predictions, batch)
+            ]
+        )
+    with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps(predictions))
+    print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
+
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+            predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
+            os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+        print("num:", len(predictions))
+        # save the predictions to a temporary file
+        random_uuid = str(uuid.uuid4())
+        with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
+            f.write(json.dumps(predictions, indent=4))
+
+        if vqa_dataset == "gqa":
+            acc = compute_gqa_accuracy(predictions)
+        else:
+            acc = compute_vqa_accuracy(
+                f"{vqa_dataset}results_{random_uuid}.json",
+                questions_json_path,
+                annotations_json_path,
+                vqa_dataset=vqa_dataset,
+            )
+        print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
+        os.makedirs("eval_results", exist_ok=True)
+        with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(f"{vqa_dataset}results_{random_uuid}.json")
+    else:
+        time.sleep(5)
+        acc = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return acc
+
+
+def evaluate_refcoco(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    model.eval().cuda()
+    loc_token_ids = []
+    for i in range(1000):
+        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    # all_ids = set(range(model.lang_encoder.lm_head.out_features))
+    # bad_words_ids = list(all_ids - set(loc_token_ids))
+    # bad_words_ids = [[b] for b in bad_words_ids]
+    # min_loc_token_id = min(loc_token_ids)
+    # max_loc_token_id = max(loc_token_ids)
+    total = 0
+    correct = 0
+    ious = []
+    if "refcocog" in tsvfile:
+        dataset_name = "refcocog"
+    elif "refcocoplus" in tsvfile:
+        dataset_name = "refcocoplus"
+    else:
+        dataset_name = "refcoco"
+    with open(tsvfile, "r") as f:
+        lines = f.readlines()
+        pbar = tqdm(lines, disable=(rank != 0))
+        for ii, line in enumerate(pbar):
+            if ii % world_size != rank:
+                continue
+            total += 1
+            line = line.rstrip()
+            uniq_id, image_id, text, region_coord, image = line.split("\t")
+
+            image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
+
+            gt_box = np.array(list(map(float, region_coord.split(","))))
+            width = image.width
+            height = image.height
+            image = image.resize((224, 224))
+            gt_box = gt_box / np.array([width, height, width, height]) * 224
+            batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
+
+
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            # attention_mask[input_ids == prebox_token_id] = 0
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+
+            model.debug_id = 0
+            with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=None,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=None,
+                    add_box=False,
+                )
+            boxes = outputs["boxes"]
+            scores = outputs["scores"]
+            if len(scores) > 0:
+                box = boxes[scores.argmax()]
+                iou = get_iou(box, gt_box)
+            else:
+                iou = 0.0
+                # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
+                tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}")
+            if iou >= 0.5:
+                correct += 1
+            pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
+            # open_cv_image = np.array(image)
+            # # Convert RGB to BGR 
+            # open_cv_image = open_cv_image[:, :, ::-1].copy() 
+            # for box, score in zip(boxes, scores):
+            #     open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+            # cv2.imwrite("output.jpg", open_cv_image)
+            # print(boxes)
+            # print(scores)
+            # exit()
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+
+# def preprocess_visual_info(Text):
+#     text = Text.split(" ")
+#     for is_idx, t in enumerate(text):
+#         if t == "is":
+#             break
+#     the_idx = is_idx
+#     while text[the_idx] != "the":
+#         the_idx -= 1
+#     obj_A = " ".join(text[the_idx+1:is_idx])
+#     second_the_idx = len(text) - 1
+#     while text[second_the_idx] != "the":
+#         second_the_idx -= 1
+#     obj_B =  " ".join(text[second_the_idx+1:])
+#     visual_obj_A = f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
+#     visual_obj_B = f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
+#     Text = Text.replace(obj_A, f"<|#object#|>{obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
+#     Text = Text.replace(obj_B, f"<|#object#|>{obj_B}<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>")
+#     return Text, obj_A, obj_B, visual_obj_A, visual_obj_B
+
+
+def preprocess_visual_info(Text):
+    text = Text.split(" ")
+    for is_idx, t in enumerate(text):
+        if t == "is":
+            break
+    the_idx = is_idx
+    while text[the_idx] != "the":
+        the_idx -= 1
+    obj_A = " ".join(text[the_idx+1:is_idx])
+    second_the_idx = len(text) - 1
+    while text[second_the_idx] != "the":
+        second_the_idx -= 1
+    obj_B = " ".join(text[second_the_idx+1:])
+    relation = " ".join(text[is_idx+1:second_the_idx])
+    visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
+    visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
+    Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
+    return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
+
+
+
+
+def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
+    assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
+    encodings = tokenizer(
+        prompt,
+        padding="longest",
+        truncation=True,
+        return_tensors="pt",
+        max_length=2000,
+    )
+    input_ids = encodings["input_ids"]
+    attention_mask = encodings["attention_mask"]
+    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+    image_start_index_list = [[x] for x in image_start_index_list]
+    image_nums = [1] * len(input_ids)
+    vision_x = batch_images.cuda()
+    lang_x = input_ids.cuda()
+    attention_mask = attention_mask.cuda()
+
+    model.debug_id = 0
+    with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+        outputs = model(
+            vision_x=vision_x,
+            lang_x=lang_x,
+            attention_mask=attention_mask,
+            labels=None,
+            image_nums=image_nums,
+            image_start_index_list=image_start_index_list,
+            added_bbox_list=visual_box_list,
+            add_box=visual_box_list is not None,
+            relations=None,
+            debug_mode=False,
+        )
+    boxes = outputs["boxes"]
+    scores = outputs["scores"]
+    if debug:
+        import pdb; pdb.set_trace()
+    if return_all:
+        return boxes, scores
+    if len(scores) == 0:
+        return None, None
+    else:
+        return boxes[scores.argmax()], scores.max()
+
+
+def evaluate_aro(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    add_visual=True,
+    add_relation=False,
+    subset=False,
+    choose_left_right=True,
+):
+    both_failed_ids = json.load(open("both_failed_ids.json"))
+    os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
+    # from groundingdino.demo.caption_grounder import caption_grounder
+    # generator = caption_grounder(
+    #     config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
+    #     checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
+    #     cpu_only=False,
+    #     box_threshold=0.1, text_threshold=0.1,
+    # )
+    dataset_name = "aro"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    model.eval().cuda()
+    total = 0
+    correct = 0
+    from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
+    vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
+    with open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/unilm/kosmos-2/labels.json") as f:
+        all_labels = json.load(f)
+        label_ids = tokenizer(all_labels).input_ids
+        label_ids = sorted(list(set([x[0] for x in label_ids])))
+
+    if subset:
+        subset_idx = json.load(open("aro_subset.json"))
+        pbar = tqdm(subset_idx, disable=(rank != 0))
+    else:
+        pbar = tqdm(vgr_dataset, disable=(rank != 0))
+    for ii, sample in enumerate(pbar):
+        if subset:
+            ORI_IDX = int(sample)
+            sample = vgr_dataset[sample]
+            # if ORI_IDX != 19036:
+            #     continue
+        if ii % world_size != rank:
+            continue
+
+        # not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
+        # if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
+        #     if rank == 0:
+        #         tqdm.write(f"SKIP: {sample['caption_options'][1]}")
+        #     continue
+        total += 1
+        # image = sample["image_options"][0]
+        image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/man_on_hydrant.png").convert("RGB")
+        image = image.resize((224, 224))
+
+        # text = sample["caption_options"][1] # 1 is true caption
+        text = "the man is sitting on the fire hydrant"
+        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
+
+
+        first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
+        first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
+
+
+        # use grounding DINO to get the first bbox
+        # caption = f"{obj_A}"
+        # with torch.no_grad():
+        #     logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption)
+        #     boxes_filt, pred_phrases = generator.postprocess(logits, boxes, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
+        # objects = {}
+        # for box, phrase in zip(boxes_filt, pred_phrases):
+        #     obj, score = phrase
+        #     obj = obj[0]
+        #     if obj not in objects:
+        #         objects[obj] = (score, box)
+        #     if objects[obj][0] < score:
+        #         objects[obj] = (score, box)
+        # try:
+        #     first_box = objects[obj_A][1].clone()
+        #     first_box[:2] -= first_box[2:] / 2
+        #     first_box[2:] += first_box[:2]
+        #     first_box = first_box.clamp(0, 0.99) * 224.0
+        #     first_box = first_box.numpy()
+        #     first_score = objects[obj_A][0]
+        # except:
+        #     first_box = None
+
+        if first_box is None:
+            text_A = "the " + obj_A
+            added_bbox_list = None
+        else:
+            text_A = visual_obj_A
+            added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
+
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
+        pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, 
+        prebox_token_id, return_all=True)
+
+
+        # open_cv_image = np.array(image)
+        # open_cv_image = open_cv_image[:, :, ::-1].copy()
+        # for box, score in zip(pre_box, pre_score):
+        #     print(box, score)
+        #     if score > 0.1:
+        #         open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (0, 255, 0), 2)
+        # cv2.imwrite(f"test1.jpg", open_cv_image)
+        # print(sample["caption_options"][idx])
+        # exit()
+
+
+
+        if pre_boxes is None:
+            pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
+            pre_scores = [1.0]
+
+        rank_list = []
+        # pre_boxes = [pre_boxes[0]]
+        # pre_scores = [pre_scores[0]]
+        for pre_box, pre_score in zip(pre_boxes, pre_scores):
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
+
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=512,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+            labels = lang_x.clone()
+
+            answer_start_idx = (labels == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1] + 1
+            # pre_box = None
+            labels[0, :answer_start_idx] = -100
+            # # labels[labels == endofobject_token_id] = -100
+            # labels[:, 0] = -100
+            # labels[labels == visual_token_id] = -100
+            # labels[labels == box_token_id] = -100
+            # labels[labels == previsual_token_id] = -100
+            # labels[labels == prebox_token_id] = -100
+            # labels[labels == endofattr_token_id] = -100
+            # labels[labels == tokenizer.pad_token_id] = -100
+            # labels[labels == media_token_id] = -100
+            # labels[labels == endofmedia_token_id] = -100
+            answer_ids = tokenizer(f" {obj_B}", add_special_tokens=False)["input_ids"]
+            labels[input_ids == visual_token_id] = -100
+            labels[input_ids == box_token_id] = -100
+            labels[input_ids == endofattr_token_id] = -100
+            labels[input_ids == previsual_token_id] = -100
+            labels[input_ids == prebox_token_id] = -100
+            labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
+            labels[torch.roll(input_ids == box_token_id, 1)] = -100
+            labels[:, 0] = -100
+            labels[input_ids == tokenizer.pad_token_id] = -100
+            labels[input_ids == media_token_id] = -100
+            labels[input_ids == endofmedia_token_id] = -100
+
+            added_bbox_list = None
+            if add_visual:
+                added_bbox_list = []
+                if first_box is not None:
+                    added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
+                if pre_box is not None:
+                    added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
+            if added_bbox_list is not None and len(added_bbox_list) == 0:
+                added_bbox_list = None
+
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox_list,
+                    add_box=added_bbox_list is not None,
+                    relations=None,
+                )
+            logits = outputs["logits"][0, answer_start_idx:]
+            # _rank = logits[0][label_ids].sort(descending=True).indices.tolist().index(label_ids.index(answer_ids[0]))
+            _rank = logits[0].sort(descending=True).indices.tolist().index(answer_ids[0])
+            print(tokenizer.decode(logits[0].sort(descending=True).indices.tolist()[:10]))
+            print(tokenizer.decode(logits[1].sort(descending=True).indices.tolist()[:10]))
+            rank_list.append(_rank)
+            # open_cv_image = np.array(image)
+            # open_cv_image = open_cv_image[:, :, ::-1].copy()
+            # if first_box is not None:
+            #     open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
+            # if pre_box is not None:
+            #     open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
+
+            # font = cv2.FONT_HERSHEY_SIMPLEX
+            # org = [10, 20]
+            # fontScale = 0.5
+            # color = (0, 0, 0)
+            # thickness = 1
+            # open_cv_image = cv2.resize(open_cv_image, (512, 512))
+            # put_text = sample["caption_options"][1]
+            # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
+            # org[1] += 20
+            # put_text = "top10 in green box"
+            # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
+            # fontScale = 1.0
+            # thickness = 2
+            # for ind in logits_list[i][0].sort(descending=True).indices[:10]:
+            #     org[1] += 20
+            #     put_text = f"{tokenizer.decode(ind)}"
+            #     open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
+            # tqdm.write(f"{tokenizer.decode(logits_list[i][0].sort(descending=True).indices[:10])}")
+        # tqdm.write(f"{rank_list}")
+        final_rank = min(rank_list)
+        if final_rank < 10:
+            correct += 1
+            TYPE = "CORRECT"
+            # if ii in both_failed_ids:
+            #     tqdm.write(f"case find->{sample['caption_options'][1]}")
+            #     image.save(f"case_study/{ii}_{rank_list}_{sample['caption_options'][1]}.jpg")
+            if rank == 0:
+                tqdm.write(f"correct: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
+        else:
+            TYPE = "WRONG"
+            if rank == 0:
+                tqdm.write(f"wrong: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
+        # cv2.imwrite(f"visualization/aro_results_{id}/{TYPE}_{ORI_IDX}.jpg", open_cv_image)
+        pbar.set_description(f"score: {correct / total:.4f} | {final_rank}")
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score, "total:", total)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+def evaluate_pisc(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    add_visual=True,
+):
+    from open_flamingo.train.instruction_template import PISC_TEMPLATES
+    dataset_name = "pisc"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    model.train().cuda()
+
+    dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
+    pbar = tqdm(dataset, disable=(rank != 0))
+
+    rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
+    rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
+    gt = []
+    pred_scores = []
+    for III, sample in enumerate(pbar):
+        if III % world_size != rank:
+            continue
+        image_path, dataset, data = sample
+        image = Image.open(image_path)
+        size = image_processor.transforms[0].size
+        image = image.resize((size, size))
+        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        boxA = data[0]
+        boxB = data[1]
+        gt_relation = data[2]
+        losses = []
+        for i_rel, option_rel in enumerate(rel_id_to_type):
+            text = PISC_TEMPLATES[0].format(relation=option_rel)
+            added_bbox = [
+                torch.tensor([boxA]).cuda(),
+                torch.tensor([boxB]).cuda(),
+            ]
+            caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
+            encodings = tokenizer(
+                caption,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+
+            labels = lang_x.clone()
+            labels[labels == tokenizer.pad_token_id] = -100
+            if add_visual:
+                # endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
+                # endofattr_next_token_index[1] += 1
+                # endofattr_next_token_id = labels[endofattr_next_token_index]
+                # </obj><visual><box></attr>NEXT_WORD
+                # </obj> predict NEXT_WORD
+                # <visual><box></attr> predict nothing
+                labels[labels == visual_token_id] = -100
+                labels[labels == box_token_id] = -100
+                labels[labels == endofattr_token_id] = -100
+                # labels[endofattr_next_token_index] = -100
+            labels[:, 0] = -100
+            answer_token_id = tokenizer(" Answer").input_ids[0]
+            answer_token_loc = (input_ids == answer_token_id).nonzero()
+            for batch_idx, idx in answer_token_loc:
+                labels[batch_idx][:idx+2] = -100
+
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox,
+                    add_box=added_bbox is not None,
+                )
+                loss_total = outputs.loss.reshape(labels.shape[0], -1)
+                loss = loss_total.sum() / (loss_total != 0).sum()
+                losses.append(loss.item())
+        pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
+        gt.append(rel_type_to_id[gt_relation])
+    gt = np.array(gt)
+    pred_scores = np.array(pred_scores)
+    pred = pred_scores.argmax(1)
+
+
+    print("total num:", len(gt))
+    recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
+    print("recalls:", recalls)
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([gt.tolist(), pred.tolist()]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        gt = []
+        pred = []
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            gt.extend(gt_part)
+            pred.extend(pred_part)
+        print("total num:", len(gt))
+        recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
+        print("recalls:", recalls)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
+            f.write(f"{gt}\n")
+            f.write(f"{pred}\n")
+            f.write(f"{recalls}\n")
+    score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/open_flamingo/eval/evaluate_temp.py b/multimodal/open_flamingo/eval/evaluate_temp.py
new file mode 100644
index 0000000000000000000000000000000000000000..38dfd2a3ecd8e9a9066427f36fa64f2ed07a194f
--- /dev/null
+++ b/multimodal/open_flamingo/eval/evaluate_temp.py
@@ -0,0 +1,1838 @@
+import argparse
+import json
+from math import ceil
+import os
+import random
+import uuid
+from collections import defaultdict
+from typing import Callable
+import time
+import cv2
+import webdataset as wds
+from sklearn.metrics import recall_score, average_precision_score
+
+import more_itertools
+import numpy as np
+import torch
+from coco_metric import compute_cider, postprocess_captioning_generation
+from eval_datasets import VQADataset, GQADataset
+from tqdm import tqdm
+from collections import Counter
+
+from vqa_metric import compute_vqa_accuracy, compute_gqa_accuracy
+from open_flamingo.eval.classification import (
+    compute_per_sample_probs,
+    compute_per_sample_loss,
+)
+from open_flamingo.eval.imagenet_utils import (
+    openai_imagenet_classnames,
+    IMAGENET_1K_CLASS_ID_TO_LABEL,
+)
+
+from open_flamingo.src.factory import create_model_and_transforms
+from PIL import Image
+from io import BytesIO
+import base64
+from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
+import string
+from lavis.datasets.builders import load_dataset
+
+
+def get_iou(box1, box2):
+    # box1 and box2 should be in the format [x1, y1, x2, y2]
+    intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
+                   max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
+    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    union = area_box1 + area_box2 - intersection
+    iou = intersection / union if union > 0 else 0
+    return iou
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--lm_path", type=str, default="facebook/opt-1.3b")
+parser.add_argument("--lm_tokenizer_path", type=str, default="facebook/opt-30b")
+parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
+parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
+parser.add_argument("--checkpoint_path", type=str, required=True)
+parser.add_argument(
+    "--results_file", type=str, default=None, help="JSON file to save results"
+)
+
+# Trial arguments
+parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
+parser.add_argument(
+    "--num_trials",
+    type=int,
+    default=1,
+    help="Number of trials to run for each shot using different demonstrations",
+)
+parser.add_argument(
+    "--trial_seeds",
+    nargs="+",
+    default=[0],
+    help="Seeds to use for each trial for picking demonstrations and eval sets",
+)
+parser.add_argument(
+    "--num_samples", type=int, default=5000, help="Number of samples to evaluate on"
+)
+
+parser.add_argument("--batch_size", type=int, default=8)
+
+# Per-dataset evaluation flags
+parser.add_argument(
+    "--eval_coco",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on COCO.",
+)
+parser.add_argument(
+    "--eval_vqav2",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on VQAV2.",
+)
+parser.add_argument(
+    "--eval_ok_vqa",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on OK-VQA.",
+)
+parser.add_argument(
+    "--eval_imagenet",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on ImageNet.",
+)
+
+parser.add_argument(
+    "--eval_flickr30",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on Flickr30.",
+)
+
+parser.add_argument(
+    "--eval_refcoco",
+    action="store_true",
+    default=False,
+    help="Whether to evaluate on RefCOCO.",
+)
+
+# Dataset arguments
+
+## Flickr30 Dataset
+parser.add_argument(
+    "--flickr_image_dir_path",
+    type=str,
+    help="Path to the flickr30/flickr30k_images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--flickr_annotations_json_path",
+    type=str,
+    help="Path to the dataset_flickr30k_coco_style.json file.",
+    default=None,
+)
+
+## COCO Dataset
+parser.add_argument(
+    "--coco_image_dir_path",
+    type=str,
+    help="Path to the flickr30/flickr30k_images directory.",
+    default=None,
+)
+parser.add_argument(
+    "--coco_annotations_json_path",
+    type=str,
+    default=None,
+)
+
+## VQAV2 Dataset
+parser.add_argument(
+    "--vqav2_image_dir_path",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--vqav2_questions_json_path",
+    type=str,
+    default=None,
+)
+parser.add_argument(
+    "--vqav2_annotations_json_path",
+    type=str,
+    default=None,
+)
+
+## OK-VQA Dataset
+parser.add_argument(
+    "--ok_vqa_image_dir_path",
+    type=str,
+    help="Path to the vqav2/train2014 directory.",
+    default=None,
+)
+parser.add_argument(
+    "--ok_vqa_questions_json_path",
+    type=str,
+    help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
+    default=None,
+)
+parser.add_argument(
+    "--ok_vqa_annotations_json_path",
+    type=str,
+    help="Path to the v2_mscoco_train2014_annotations.json file.",
+    default=None,
+)
+
+## Imagenet dataset
+parser.add_argument("--imagenet_root", type=str, default="/tmp")
+
+## RefCOCO dataset
+parser.add_argument("--refcoco_tsvfile", type=str, default=None)
+
+parser.add_argument(
+    "--location_token_num",
+    default=1000,
+    type=int,
+)
+# distributed training
+parser.add_argument(
+    "--dist-url",
+    default="env://",
+    type=str,
+    help="url used to set up distributed training",
+)
+parser.add_argument(
+    "--dist-backend", default="nccl", type=str, help="distributed backend"
+)
+parser.add_argument(
+    "--horovod",
+    default=False,
+    action="store_true",
+    help="Use horovod for distributed training.",
+)
+parser.add_argument(
+    "--no-set-device-rank",
+    default=False,
+    action="store_true",
+    help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+)
+parser.add_argument(
+    "--dist",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--lora",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--lora_r",
+    default=16,
+    type=int,
+    required=False,
+)
+parser.add_argument(
+    "--legacy",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--special",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--id",
+    default=0,
+    type=int,
+    required=False,
+)
+
+parser.add_argument(
+    "--eval_gqa",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--use_sam",
+    default=None,
+    type=str,
+    required=False,
+)
+parser.add_argument(
+    "--add_visual_token",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--use_format_v2",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_aro",
+    default=False,
+    action="store_true",
+)
+parser.add_argument(
+    "--eval_pisc",
+    default=False,
+    action="store_true",
+)
+
+
+class OKVQAPostProcess():
+    def __init__(self):
+        self._lemmatizer = None
+
+    def _lemmatize(self, answers):
+        def apply(answer):
+            doc = self.lemmatizer(answer)
+
+            words = []
+            for token in doc:
+                if token.pos_ in ["NOUN", "VERB"]:
+                    words.append(token.lemma_)
+                else:
+                    words.append(token.text)
+            answer = " ".join(words)
+
+            return answer
+
+        return [apply(answer) for answer in answers]
+
+    @property
+    def lemmatizer(self):
+        if self._lemmatizer is None:
+            try:
+                import spacy
+
+                self._lemmatizer = spacy.load("en_core_web_sm")
+            except ImportError:
+                logging.error(
+                    """
+                    Please install spacy and en_core_web_sm model to apply lemmatization.
+                    python -m spacy download en_core_web_sm
+                    OR
+                    import spacy.cli
+                    spacy.cli.download("en_core_web_sm")
+                    """
+                )
+                exit(1)
+
+        return self._lemmatizer
+        
+
+def main():
+    args = parser.parse_args()
+    if args.dist:
+        args.local_rank, args.rank, args.world_size = world_info_from_env()
+        print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
+        device_id = init_distributed_device(args)
+    else:
+        args.rank = 0
+        args.world_size = 1
+        print(f"rank: {args.rank} world_size: {args.world_size}")
+    
+    if "sam" in args.checkpoint_path:
+        args.use_sam = "vit_l"
+
+    args.add_visual_token = True
+    if "lora" in args.checkpoint_path:
+        args.lora = True
+
+
+    args.add_pe = False
+    args.add_box = True
+    args.relation = False
+    args.enhance_data = False
+    args.use_format_v2 = True
+
+
+
+    import hashlib
+    args.id = hashlib.sha224(args.checkpoint_path.encode()).hexdigest()
+
+    # load model
+    flamingo, image_processor, tokenizer, vis_embed_size = create_model_and_transforms(
+        args.vision_encoder_path,
+        args.vision_encoder_pretrained,
+        args.lm_path,
+        args.lm_tokenizer_path,
+        location_token_num=args.location_token_num,
+        lora=args.lora,
+        lora_r=16,
+        use_sam=args.use_sam,
+        add_visual_token=args.add_visual_token,
+        use_format_v2=args.use_format_v2,
+        add_box=args.add_box,
+        add_pe=args.add_pe,
+        add_relation=args.relation,
+        enhance_data=args.enhance_data,
+    )
+    flamingo.use_format_v2 = args.use_format_v2
+    if args.special:
+        flamingo.special = True
+    else:
+        flamingo.special = False
+    if args.legacy:
+        flamingo.legacy = True
+        print("use legacy evaluation")
+    flamingo.step_num = int(args.checkpoint_path.split("/")[-1].split(".")[0].split("_")[-1])
+    flamingo.expr_name = args.checkpoint_path.split("/")[-2]
+    if args.rank == 0:
+        print("legacy", True if hasattr(flamingo, "legacy") else False)
+        print("step:", flamingo.step_num)
+        print("expr:", flamingo.expr_name)
+        print("use format v2:", flamingo.use_format_v2)
+        print(args)
+    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
+    model_state_dict = {}
+    for key in checkpoint["model_state_dict"].keys():
+        model_state_dict[key.replace("module.", "")] = checkpoint["model_state_dict"][key]
+    if "vision_encoder.logit_scale"in model_state_dict:
+        # previous checkpoint has some unnecessary weights
+        del model_state_dict["vision_encoder.logit_scale"]
+        del model_state_dict["vision_encoder.visual.proj"]
+        del model_state_dict["vision_encoder.visual.ln_post.weight"]
+        del model_state_dict["vision_encoder.visual.ln_post.bias"]
+    flamingo.load_state_dict(model_state_dict, strict=True)
+    results = defaultdict(list)
+    if args.eval_coco:
+        print("Evaluating on COCO...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                cider_score = evaluate_coco_flickr(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.coco_image_dir_path,
+                    annotations_json_path=args.coco_annotations_json_path,
+                    device=args.device,
+                    seed=seed,
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+                print(f"Shots {shot} Trial {trial} CIDEr score: {cider_score}")
+                scores.append(cider_score)
+            print(f"Shots {shot} Mean CIDEr score: {np.mean(scores)}")
+            results["coco"].append(
+                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
+            )
+
+    if args.eval_ok_vqa:
+        print("Evaluating on OK-VQA...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                ok_vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.ok_vqa_image_dir_path,
+                    questions_json_path=args.ok_vqa_questions_json_path,
+                    annotations_json_path=args.ok_vqa_annotations_json_path,
+                    vqa_dataset="ok_vqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["ok_vqa"].append(
+                {"shots": shot, "score": ok_vqa_score}
+            )
+
+    if args.eval_vqav2:
+        print("Evaluating on VQAv2...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    image_dir_path=args.vqav2_image_dir_path,
+                    questions_json_path=args.vqav2_questions_json_path,
+                    annotations_json_path=args.vqav2_annotations_json_path,
+                    vqa_dataset="vqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["vqav2"].append(
+                {"shots": shot, "score": vqa_score}
+            )
+
+    if args.eval_gqa:
+        print("Evaluating on GQA...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                vqa_score = evaluate_vqa(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    vqa_dataset="gqa",
+                    vis_embed_size=vis_embed_size,
+                    rank=args.rank,
+                    world_size=args.world_size,
+                    id=args.id,
+                )
+            results["gqa"].append(
+                {"shots": shot, "score": vqa_score}
+            )
+
+    if args.eval_imagenet:
+        print("Evaluating on ImageNet...")
+        for shot in args.shots:
+            scores = []
+            for seed, trial in zip(args.trial_seeds, range(args.num_trials)):
+                imagenet_score = evaluate_imagenet(
+                    model=flamingo,
+                    tokenizer=tokenizer,
+                    image_processor=image_processor,
+                    batch_size=args.batch_size,
+                    num_samples=args.num_samples,
+                    num_shots=shot,
+                    device=args.device,
+                    seed=seed,
+                    imagenet_root=args.imagenet_root,
+                )
+                print(
+                    f"Shots {shot} Trial {trial} " f"ImageNet score: {imagenet_score}"
+                )
+                scores.append(imagenet_score)
+            print(f"Shots {shot} Mean ImageNet score: {np.mean(scores)}")
+            results["imagenet"].append(
+                {"shots": shot, "trials": scores, "mean": np.mean(scores)}
+            )
+
+    if args.eval_refcoco:
+        print("Evaluating on RefCOCO...")
+        refcoco_score = evaluate_refcoco(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["refcoco"].append(
+            {"score": refcoco_score}
+        )
+    if args.eval_aro:
+        print("Evaluating on ARO...")
+        _func = evaluate_aro
+        # print("Evaluating on ARO ORI...")
+        # _func = evaluate_aro_ori
+        aro_score = _func(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+            add_relation=args.relation,
+        )
+        results["aro"].append(
+            {"score": aro_score}
+        )
+    if args.eval_pisc:
+        print("Evaluating on ARO...")
+        aro_score = evaluate_pisc(
+            model=flamingo,
+            tokenizer=tokenizer,
+            image_processor=image_processor,
+            batch_size=args.batch_size,
+            device=args.device,
+            tsvfile=args.refcoco_tsvfile,
+            vis_embed_size=vis_embed_size,
+            rank=args.rank,
+            world_size=args.world_size,
+            id=args.id,
+        )
+        results["pisc"].append(
+            {"score": aro_score}
+        )
+
+def prepare_batch_images(batch, image_processor):
+    batch_images = None
+    for b in batch:
+        b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        if batch_images is None:
+            batch_images = b_image
+        else:
+            batch_images = torch.cat([batch_images, b_image], dim=0)
+    return batch_images
+
+def get_outputs(
+    model,
+    batch_images,
+    attention_mask,
+    max_generation_length,
+    min_generation_length,
+    num_beams,
+    length_penalty,
+    input_ids,
+    image_start_index_list=None,
+    image_nums=None,
+    bad_words_ids=None,
+):
+    with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+        outputs = model.generate(
+            batch_images,
+            input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_generation_length,
+            min_length=min_generation_length,
+            num_beams=num_beams,
+            length_penalty=length_penalty,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+            bad_words_ids=bad_words_ids,
+        )
+
+    outputs = outputs[:, len(input_ids[0]) :]
+    return outputs
+
+
+def evaluate_coco_flickr(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    image_dir_path,
+    annotations_json_path,
+    seed=42,
+    max_generation_length=20,
+    num_beams=1,
+    length_penalty=-2.0,
+    device=-1,
+    is_flickr=False,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    """Evaluate a model on COCO dataset.
+
+    Args:
+        model (nn.Module): model to evaluate
+        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
+        image_processor : image processor for the model
+        batch_size (int): batch size
+        image_dir_path (str, optional): path to the directory containing the images.
+        annotations_json_path (str, optional): path to the json file containing the annotations.
+        seed (int, optional): seed for random number generator. Defaults to 42.
+        max_generation_length (int, optional): maximum length of the generated caption. Defaults to 10.
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000.
+        query_set_size (int, optional): number of samples to use for query set. Defaults to 2048.
+        num_shots (int, optional): number of in-context samples to use. Defaults to 8.
+        device (int, optional): device to use. Defaults to -1.
+        num_workers (int, optional): number of workers to use for dataloader. Defaults to 4.
+        is_flickr (bool): defines if that data is COCO or Flickr. Defaults to False (COCO).
+
+    Returns:
+        float: CIDEr score
+
+    """
+    # eval_dataset = COCOFlickrDataset(
+    #     image_dir_path=image_dir_path,
+    #     annotations_path=annotations_json_path,
+    #     is_flickr=is_flickr,
+    # )
+    coco_dataset = load_dataset("coco_caption")
+    eval_dataset = coco_dataset["test"]
+
+
+    model.eval().cuda()
+    predictions = defaultdict()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    # if "peft" in lang_encoder_name:
+        # lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
+    try:
+        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    except:
+        pass
+
+    def get_prompt(sample):
+        return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+
+    tokenizer.padding_side = "left"
+    cnt = 0
+    if world_size > 1:
+        torch.distributed.barrier()
+    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
+    )):
+        if ii % world_size != rank:
+            continue
+        cnt += len(batch)
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        batch_text = [get_prompt(s) for s in batch]
+        encodings = tokenizer(
+            batch_text,
+            padding="longest",
+            truncation=True,
+            return_tensors="pt",
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"].cuda()
+        attention_mask = encodings["attention_mask"].cuda()
+        skip_special_tokens = False
+        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
+            if rank == 0:
+                tqdm.write("use legacy model")
+            skip_special_tokens = True
+            for i in range(len(input_ids)):
+                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
+                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
+                input_ids[i, media_token_index - 1] = media_token_id
+                input_ids[i, media_token_index] = pad_token_id
+                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
+                input_ids[i, endofmedia_token_index] = bos_token_id
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        if "llama" in lang_encoder_name:
+            attention_mask[input_ids == 0] = 0
+        outputs = get_outputs(
+            model=model,
+            batch_images=batch_images,
+            attention_mask=attention_mask,
+            max_generation_length=30,
+            min_generation_length=8,
+            num_beams=5,
+            length_penalty=0,
+            input_ids=input_ids,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+        )
+        new_predictions = [
+            postprocess_captioning_generation(out).replace('"', "")
+            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ]
+        # if rank == 0:
+        #     tqdm.write(f"{batch_images.shape} {batch[0]} pred: {new_predictions[0]}")
+
+        for i, sample in enumerate(batch):
+            predictions[int(sample["image_id"])] = {
+                "caption": new_predictions[i],
+            }
+    results_path = (
+        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
+        if is_flickr
+        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
+    )
+    with open(results_path, "w") as f:
+        f.write(
+            json.dumps(
+                [
+                    {"image_id": k, "caption": predictions[k]["caption"]}
+                    for k in predictions
+                ],
+                indent=2,
+            )
+        )
+    print("save to", results_path)
+    del predictions
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            part_results_path = (
+                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
+                if is_flickr
+                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
+            )
+            print("load", part_results_path)
+            predictions.extend(json.load(open(part_results_path)))
+            os.remove(part_results_path)
+        print("num:", len(predictions))
+        results_path = (
+            f"flickrresults_{lang_encoder_name}.json"
+            if is_flickr
+            else f"cocoresults_{lang_encoder_name}.json"
+        )
+        json.dump(predictions, open(results_path, "w"), indent=2)
+
+        metrics = compute_cider(
+            result_path=results_path,
+            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
+        )
+        os.makedirs("eval_results", exist_ok=True)
+        acc = metrics["CIDEr"]
+        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(results_path)
+    else:
+        metrics = {}
+        metrics["CIDEr"] = 0.0
+
+    return metrics["CIDEr"]
+
+
+def evaluate_vqa(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    image_dir_path=None,
+    questions_json_path=None,
+    annotations_json_path=None,
+    vqa_dataset="vqa",
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    """
+    Evaluate a model on VQA datasets. Currently supports VQA v2.0.
+
+    Args:
+        model (nn.Module): model to evaluate
+        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
+        image_processor : image processor for the model
+        batch_size (int): batch size
+        image_dir_path (str): path to image directory
+        questions_json_path (str): path to questions json file
+        annotations_json_path (str): path to annotations json file
+        seed (int, optional): random seed. Defaults to 42.
+        max_generation_length (int, optional): max generation length. Defaults to 5.
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
+        query_set_size (int, optional): size of the query set. Defaults to 2048.
+        num_shots (int, optional): number of shots to use. Defaults to 8.
+        device (int, optional): device to use. Defaults to -1 (cpu).
+        num_workers (int, optional): number of workers to use. Defaults to 4.
+        vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
+    Returns:
+        float: accuracy score
+    """
+    if world_size > 1:
+        torch.distributed.barrier()
+    if vqa_dataset == "gqa":
+        eval_dataset = GQADataset()
+    else:
+        eval_dataset = VQADataset(
+            image_dir_path=image_dir_path,
+            question_path=questions_json_path,
+            annotations_path=annotations_json_path,
+            vqa_dataset=vqa_dataset,
+        )
+    postprocessor = OKVQAPostProcess()
+    try:
+        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    except:
+        pass
+    def get_prompt(sample):
+        return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
+        # return f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+
+    model.eval().cuda()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    if "peft" in lang_encoder_name:
+        lang_encoder_name = model.lang_encoder.base_model.model.__class__.__name__.lower()
+    predictions = []
+    tokenizer.padding_side = "left"
+    if world_size > 1:
+        torch.distributed.barrier()
+    this_tot = 0
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size
+    )):
+        if ii % world_size != rank:
+            continue
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        batch_text = [get_prompt(s) for s in batch]
+        encodings = tokenizer(
+            batch_text,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"].cuda()
+        attention_mask = encodings["attention_mask"].cuda()
+        skip_special_tokens = True
+        if hasattr(model, "legacy") and model.legacy and "opt" in lang_encoder_name:
+            if rank == 0:
+                tqdm.write("use legacy model")
+            for i in range(len(input_ids)):
+                media_token_index = (input_ids[i] == media_token_id).nonzero()[0,0]
+                endofmedia_token_index = (input_ids[i] == endofmedia_token_id).nonzero()[0,0]
+                input_ids[i, media_token_index - 1] = media_token_id
+                input_ids[i, media_token_index] = pad_token_id
+                input_ids[i, endofmedia_token_index - 1] = endofmedia_token_id
+                input_ids[i, endofmedia_token_index] = bos_token_id
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        if "llama" in lang_encoder_name:
+            attention_mask[input_ids == 0] = 0
+        outputs = get_outputs(
+            model=model,
+            batch_images=batch_images,
+            attention_mask=attention_mask,
+            max_generation_length=10,
+            min_generation_length=1,
+            num_beams=5,
+            length_penalty=0,
+            input_ids=input_ids,
+            image_start_index_list=image_start_index_list,
+            image_nums=image_nums,
+        )
+        # postprocess begin
+        new_predictions = [
+            out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=skip_special_tokens)
+        ]
+        if vqa_dataset == "ok_vqa":
+            new_predictions = postprocessor._lemmatize(new_predictions)
+        if model.special:
+            for i in range(len(new_predictions)):
+                for answer, _ in Counter(batch[i]['answers']).most_common():
+                    if answer in new_predictions[i]:
+                        new_predictions[i] = answer
+                        break
+                    if "cant" in new_predictions[i] and "no" == answer:
+                        new_predictions[i] = answer
+                        break
+                    if "can" in new_predictions[i] and "not" not in new_predictions[i] and "cant" not in new_predictions[i] and "yes" == answer:
+                        new_predictions[i] = answer
+                        break
+        
+        this_tot += 1
+        if rank == 0 and this_tot % 20 == 0:
+            for i in range(1):
+                tqdm.write(f"question: {batch[i]['question']}\nanswer: {batch[i]['answers']}model output: " + new_predictions[i])
+
+        predictions.extend(
+            [
+                {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
+                for p, sample in zip(new_predictions, batch)
+            ]
+        )
+    with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps(predictions))
+    print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
+
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+            predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
+            os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+        print("num:", len(predictions))
+        # save the predictions to a temporary file
+        random_uuid = str(uuid.uuid4())
+        with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
+            f.write(json.dumps(predictions, indent=4))
+
+        if vqa_dataset == "gqa":
+            acc = compute_gqa_accuracy(predictions)
+        else:
+            acc = compute_vqa_accuracy(
+                f"{vqa_dataset}results_{random_uuid}.json",
+                questions_json_path,
+                annotations_json_path,
+                vqa_dataset=vqa_dataset,
+            )
+        print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
+        os.makedirs("eval_results", exist_ok=True)
+        with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(f"{vqa_dataset}results_{random_uuid}.json")
+    else:
+        time.sleep(5)
+        acc = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return acc
+
+
+def evaluate_refcoco(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    model.eval().cuda()
+    loc_token_ids = []
+    for i in range(1000):
+        loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    # all_ids = set(range(model.lang_encoder.lm_head.out_features))
+    # bad_words_ids = list(all_ids - set(loc_token_ids))
+    # bad_words_ids = [[b] for b in bad_words_ids]
+    # min_loc_token_id = min(loc_token_ids)
+    # max_loc_token_id = max(loc_token_ids)
+    total = 0
+    correct = 0
+    ious = []
+    if "refcocog" in tsvfile:
+        dataset_name = "refcocog"
+    elif "refcocoplus" in tsvfile:
+        dataset_name = "refcocoplus"
+    else:
+        dataset_name = "refcoco"
+    with open(tsvfile, "r") as f:
+        lines = f.readlines()
+        pbar = tqdm(lines, disable=(rank != 0))
+        for ii, line in enumerate(pbar):
+            if ii % world_size != rank:
+                continue
+            total += 1
+            line = line.rstrip()
+            uniq_id, image_id, text, region_coord, image = line.split("\t")
+
+            image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/cat.png").convert("RGB")
+            # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/temp/262148000.png")
+
+            gt_box = np.array(list(map(float, region_coord.split(","))))
+            width = image.width
+            height = image.height
+            image = image.resize((224, 224))
+            gt_box = gt_box / np.array([width, height, width, height]) * 224
+            batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>{text.rstrip('.').strip()}<|#endofobject#|><|#visual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>the cat<|#visual#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"]
+            # prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>a man<|#visual#|> is doing a trick on a skateboard<|#visual#|>"]
+
+
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            # attention_mask[input_ids == prebox_token_id] = 0
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+
+            model.debug_id = 0
+            with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=None,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=None,
+                    add_box=False,
+                )
+            boxes = outputs["boxes"]
+            scores = outputs["scores"]
+            if len(scores) > 0:
+                box = boxes[scores.argmax()]
+                iou = get_iou(box, gt_box)
+            else:
+                iou = 0.0
+                # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
+                tqdm.write(f"no output for: {uniq_id}, {image_id}, {text}")
+            if iou >= 0.5:
+                correct += 1
+            pbar.set_description(f"iou: {iou:.2f} score: {correct / total:.4f}")
+            # open_cv_image = np.array(image)
+            # # Convert RGB to BGR 
+            # open_cv_image = open_cv_image[:, :, ::-1].copy() 
+            # for box, score in zip(boxes, scores):
+            #     open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (255, 0, 0), 2)
+            # cv2.imwrite("output.jpg", open_cv_image)
+            # print(boxes)
+            # print(scores)
+            # exit()
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+def preprocess_visual_info(Text):
+    text = Text.split(" ")
+    for is_idx, t in enumerate(text):
+        if t == "is":
+            break
+    the_idx = is_idx
+    while text[the_idx] != "the":
+        the_idx -= 1
+    obj_A = " ".join(text[the_idx+1:is_idx])
+    second_the_idx = len(text) - 1
+    while text[second_the_idx] != "the":
+        second_the_idx -= 1
+    obj_B = " ".join(text[second_the_idx+1:])
+    relation = " ".join(text[is_idx+1:second_the_idx])
+    visual_obj_A = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>"
+    visual_obj_B = f"<|#object#|><|#previsual#|><|#prebox#|><|#object#|>the {obj_B}<|#endofobject#|>"
+    Text = f"{visual_obj_A} is {relation} {visual_obj_B}"
+    return Text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation
+
+
+
+def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox, debug=False, return_all=False):
+    assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
+    encodings = tokenizer(
+        prompt,
+        padding="longest",
+        truncation=True,
+        return_tensors="pt",
+        max_length=2000,
+    )
+    input_ids = encodings["input_ids"]
+    attention_mask = encodings["attention_mask"]
+    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+    image_start_index_list = [[x] for x in image_start_index_list]
+    image_nums = [1] * len(input_ids)
+    vision_x = batch_images.cuda()
+    lang_x = input_ids.cuda()
+    attention_mask = attention_mask.cuda()
+    prebox_mask = (input_ids == prebox_token_id)
+    if mask_prebox and prebox_mask.any():
+        attention_mask[prebox_mask] = 0
+
+    model.debug_id = 0
+    with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+        outputs = model(
+            vision_x=vision_x,
+            lang_x=lang_x,
+            attention_mask=attention_mask,
+            labels=None,
+            image_nums=image_nums,
+            image_start_index_list=image_start_index_list,
+            added_bbox_list=visual_box_list,
+            add_box=visual_box_list is not None,
+            relations=None,
+            debug_mode=False,
+        )
+    boxes = outputs["boxes"]
+    scores = outputs["scores"]
+    if debug:
+        import pdb; pdb.set_trace()
+    if return_all:
+        return boxes, scores
+    if len(scores) == 0:
+        return None, None
+    else:
+        return boxes[scores.argmax()], scores.max()
+
+
+def evaluate_aro(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    add_visual=True,
+    add_relation=False,
+    subset=True,
+    choose_left_right=True,
+):
+    os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
+    from groundingdino.demo.caption_grounder import caption_grounder
+    generator = caption_grounder(
+        config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
+        checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
+        cpu_only=False,
+        box_threshold=0.1, text_threshold=0.1,
+    )
+    dataset_name = "aro"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    model.eval().cuda()
+    total = 0
+    correct = 0
+    from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
+    vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
+    with open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/unilm/kosmos-2/labels.json") as f:
+        all_labels = json.load(f)
+        label_ids = tokenizer(all_labels).input_ids
+        label_ids = sorted(list(set([x[0] for x in label_ids])))
+
+    if subset:
+        subset_idx = json.load(open("aro_subset.json"))
+        pbar = tqdm(subset_idx, disable=(rank != 0))
+    else:
+        pbar = tqdm(vgr_dataset, disable=(rank != 0))
+    
+
+    exist_total = 0
+    for ii, sample in enumerate(pbar):
+        if subset:
+            ORI_IDX = int(sample)
+            sample = vgr_dataset[sample]
+            # if ORI_IDX != 19036:
+            #     continue
+        if ii % world_size != rank:
+            continue
+
+        not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
+        if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
+            if rank == 0:
+                tqdm.write(f"SKIP: {sample['caption_options'][1]}")
+            continue
+        total += 1
+        image = sample["image_options"][0]
+        # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
+        image = image.resize((224, 224))
+
+        chosen_idx = 0
+        text = sample["caption_options"][chosen_idx] # 1 is true caption
+        # text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
+        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
+
+
+        first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
+        first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True, return_all=False)
+
+
+        # use grounding DINO to get the first bbox
+        # caption = f"{obj_A}"
+        # with torch.no_grad():
+        #     logits, boxes = generator.ground_caption_raw(image_pil=image, caption=caption)
+        #     boxes_filt, pred_phrases = generator.postprocess(logits, boxes, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
+        # objects = {}
+        # for box, phrase in zip(boxes_filt, pred_phrases):
+        #     obj, score = phrase
+        #     obj = obj[0]
+        #     if obj not in objects:
+        #         objects[obj] = (score, box)
+        #     if objects[obj][0] < score:
+        #         objects[obj] = (score, box)
+        # try:
+        #     first_box = objects[obj_A][1].clone()
+        #     first_box[:2] -= first_box[2:] / 2
+        #     first_box[2:] += first_box[:2]
+        #     first_box = first_box.clamp(0, 0.99) * 224.0
+        #     first_box = first_box.numpy()
+        #     first_score = objects[obj_A][0]
+        # except:
+        #     first_box = None
+
+        if first_box is None:
+            text_A = "the " + obj_A
+            added_bbox_list = None
+        else:
+            text_A = visual_obj_A
+            added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
+
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
+        pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, 
+        prebox_token_id, mask_prebox=False, debug=False, return_all=True)
+
+
+        open_cv_image = np.array(image)
+        open_cv_image = open_cv_image[:, :, ::-1].copy()
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        fontScale = 0.5
+        color = (0, 0, 0)
+        thickness = 1
+        if first_box is not None:
+            open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
+        exist_flag = False
+        for box, score in zip(pre_boxes, pre_scores):
+            if score >= 0.5:
+                exist_flag = True
+                open_cv_image = cv2.rectangle(open_cv_image, box[:2].astype(int), box[2:].astype(int), (0, 255, 0), 2)
+                org = box[:2].astype(int)
+                org[1] += 20
+                org[0] += 10
+                open_cv_image = cv2.putText(open_cv_image, f"{score:.2f}", org, font, fontScale, (255, 255, 255), thickness, cv2.LINE_AA)
+        open_cv_image = cv2.resize(open_cv_image, (512, 512))
+        put_text = sample["caption_options"][chosen_idx]
+        org = [10, 20]
+        open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
+        # cv2.imwrite(f"visualization/aro_results_{id}/{str(ORI_IDX).zfill(8)}.jpg", open_cv_image)
+        if exist_flag:
+            exist_total += 1
+        continue
+
+
+
+        if pre_boxes is None:
+            pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
+            pre_scores = [1.0]
+
+        rank_list = []
+        # pre_boxes = [pre_boxes[0]]
+        # pre_scores = [pre_scores[0]]
+        for pre_box, pre_score in zip(pre_boxes, pre_scores):
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
+
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=512,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+            labels = lang_x.clone()
+
+            answer_start_idx = (labels == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1] + 1
+            # pre_box = None
+            labels[0, :answer_start_idx] = -100
+            # # labels[labels == endofobject_token_id] = -100
+            # labels[:, 0] = -100
+            # labels[labels == visual_token_id] = -100
+            # labels[labels == box_token_id] = -100
+            # labels[labels == previsual_token_id] = -100
+            # labels[labels == prebox_token_id] = -100
+            # labels[labels == endofattr_token_id] = -100
+            # labels[labels == tokenizer.pad_token_id] = -100
+            # labels[labels == media_token_id] = -100
+            # labels[labels == endofmedia_token_id] = -100
+            answer_ids = tokenizer(f" {obj_B}", add_special_tokens=False)["input_ids"]
+            labels[input_ids == visual_token_id] = -100
+            labels[input_ids == box_token_id] = -100
+            labels[input_ids == endofattr_token_id] = -100
+            labels[input_ids == previsual_token_id] = -100
+            labels[input_ids == prebox_token_id] = -100
+            labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
+            labels[torch.roll(input_ids == box_token_id, 1)] = -100
+            labels[:, 0] = -100
+            labels[input_ids == tokenizer.pad_token_id] = -100
+            labels[input_ids == media_token_id] = -100
+            labels[input_ids == endofmedia_token_id] = -100
+
+            added_bbox_list = None
+            if add_visual:
+                added_bbox_list = []
+                if first_box is not None:
+                    added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
+                if pre_box is not None:
+                    added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
+            if added_bbox_list is not None and len(added_bbox_list) == 0:
+                added_bbox_list = None
+
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox_list,
+                    add_box=added_bbox_list is not None,
+                    relations=None,
+                )
+            logits = outputs["logits"][0, answer_start_idx:]
+            _rank = logits[0][label_ids].sort(descending=True).indices.tolist().index(label_ids.index(answer_ids[0]))
+            rank_list.append(_rank)
+            # open_cv_image = np.array(image)
+            # open_cv_image = open_cv_image[:, :, ::-1].copy()
+            # if first_box is not None:
+            #     open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
+            # if pre_box is not None:
+            #     open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
+
+            # font = cv2.FONT_HERSHEY_SIMPLEX
+            # org = [10, 20]
+            # fontScale = 0.5
+            # color = (0, 0, 0)
+            # thickness = 1
+            # open_cv_image = cv2.resize(open_cv_image, (512, 512))
+            # put_text = sample["caption_options"][1]
+            # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
+            # org[1] += 20
+            # put_text = "top10 in green box"
+            # open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
+            # fontScale = 1.0
+            # thickness = 2
+            # for ind in logits_list[i][0].sort(descending=True).indices[:10]:
+            #     org[1] += 20
+            #     put_text = f"{tokenizer.decode(ind)}"
+            #     open_cv_image = cv2.putText(open_cv_image, put_text, org, font, fontScale, color, thickness, cv2.LINE_AA)
+            # tqdm.write(f"{tokenizer.decode(logits_list[i][0].sort(descending=True).indices[:10])}")
+        # tqdm.write(f"{rank_list}")
+        final_rank = min(rank_list)
+        if final_rank < 10:
+            correct += 1
+            TYPE = "CORRECT"
+            if rank == 0:
+                tqdm.write(f"correct: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
+        else:
+            TYPE = "WRONG"
+            if rank == 0:
+                tqdm.write(f"wrong: {final_rank} " + prompt[0].replace(tokenizer.pad_token, ""))
+        # cv2.imwrite(f"visualization/aro_results_{id}/{TYPE}_{ORI_IDX}.jpg", open_cv_image)
+        pbar.set_description(f"score: {correct / total:.4f} | {final_rank}")
+
+
+
+
+
+    print(exist_total)
+    exit()
+
+
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score, "total:", total)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+
+
+def evaluate_aro_ori(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    add_visual=True,
+    add_relation=False,
+    subset=True,
+    choose_left_right=True,
+    only_highest=True,
+):
+    os.makedirs(f"visualization/aro_results_{id}", exist_ok=True)
+    dataset_name = "aroori"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    model.eval().cuda()
+    total = 0
+    correct = 0
+    from open_flamingo.eval.dataset_zoo import VG_Relation, VG_Attribution
+    vgr_dataset = VG_Relation(image_preprocess=None, download=True, root_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/vision-language-models-are-bows/data")
+    if subset:
+        subset_idx = json.load(open("aro_subset.json"))
+        pbar = tqdm(subset_idx, disable=(rank != 0))
+    else:
+        pbar = tqdm(vgr_dataset, disable=(rank != 0))
+    for ii, sample in enumerate(pbar):
+        if subset:
+            ORI_IDX = int(sample)
+            sample = vgr_dataset[sample]
+            # if ORI_IDX != 19036:
+            #     continue
+        if ii % world_size != rank:
+            continue
+
+        not_left_right = ("near" in sample["caption_options"][0] or "next to" in sample["caption_options"][0] or "in front of" in sample["caption_options"][0] or "behind" in sample["caption_options"][0]) or ("left" not in sample["caption_options"][0] and "right" not in sample["caption_options"][0])
+        if (choose_left_right and not_left_right) or (not choose_left_right and not not_left_right):
+            if rank == 0:
+                tqdm.write(f"SKIP: {sample['caption_options'][1]}")
+            continue
+        total += 1
+        image = sample["image_options"][0]
+        # image = Image.open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/yolo.png").convert("RGB")
+        image = image.resize((224, 224))
+        debug_data = []
+        final_losses = []
+        for idx in range(2):
+            text = sample["caption_options"][idx] # 1 is true caption
+            # text = "the dog is sitting on the floor" if idx == 1 else "the floor is sitting on the dog"
+            batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+            text, obj_A, visual_obj_A, obj_B, visual_obj_B, relation = preprocess_visual_info(text)
+            first_text = f"<|#object#|>the {obj_A}<|#endofobject#|><|#visual#|>"
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
+            first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, mask_prebox=True, return_all=False)
+            if first_box is None:
+                text_A = "the " + obj_A
+                added_bbox_list = None
+            else:
+                text_A = visual_obj_A
+                added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
+
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|>"]
+            pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, 
+            prebox_token_id, mask_prebox=False, debug=False, return_all=True)
+            if pre_boxes is None:
+                pre_boxes = [np.array([0.0, 0.0, 223.0, 223.0])]
+                pre_scores = [1.0]
+
+            loss_list = []
+            if only_highest:
+                pre_boxes = [pre_boxes[0]]
+                pre_scores = [pre_scores[0]]
+            for pre_box, pre_score in zip(pre_boxes, pre_scores):
+                prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text_A} is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {obj_B}<|#endofobject#|>"]
+
+                encodings = tokenizer(
+                    prompt,
+                    padding="longest",
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=512,
+                )
+                input_ids = encodings["input_ids"]
+                attention_mask = encodings["attention_mask"]
+                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+                image_start_index_list = [[x] for x in image_start_index_list]
+                image_nums = [1] * len(input_ids)
+                vision_x = batch_images.cuda()
+                lang_x = input_ids.cuda()
+                attention_mask = attention_mask.cuda()
+                labels = lang_x.clone()
+
+
+                labels[input_ids == visual_token_id] = -100
+                labels[input_ids == box_token_id] = -100
+                labels[input_ids == endofattr_token_id] = -100
+                labels[input_ids == previsual_token_id] = -100
+                labels[input_ids == prebox_token_id] = -100
+                labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
+                labels[torch.roll(input_ids == box_token_id, 1)] = -100
+                labels[:, 0] = -100
+                labels[input_ids == tokenizer.pad_token_id] = -100
+                labels[input_ids == media_token_id] = -100
+                labels[input_ids == endofmedia_token_id] = -100
+
+                added_bbox_list = None
+                if add_visual:
+                    added_bbox_list = []
+                    if first_box is not None:
+                        added_bbox_list.append(torch.tensor(first_box).unsqueeze(0).cuda().float() / 224)
+                    if pre_box is not None:
+                        added_bbox_list.append(torch.tensor(pre_box).unsqueeze(0).cuda().float() / 224)
+                if added_bbox_list is not None and len(added_bbox_list) == 0:
+                    added_bbox_list = None
+
+                with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                    outputs = model(
+                        vision_x=vision_x,
+                        lang_x=lang_x,
+                        attention_mask=attention_mask,
+                        labels=labels,
+                        image_nums=image_nums,
+                        image_start_index_list=image_start_index_list,
+                        added_bbox_list=added_bbox_list,
+                        add_box=added_bbox_list is not None,
+                        relations=None,
+                    )
+                loss_list.append((outputs["loss"].sum() / (outputs["loss"] != 0).sum()).item())
+                debug_data.append([outputs, first_box, first_score, pre_box, pre_scores])
+            final_loss = min(loss_list)
+            final_losses.append(final_loss)
+        if final_losses[0] >= final_losses[1]:
+            correct += 1
+        else:
+            import pdb; pdb.set_trace()
+            pass
+        pbar.set_description(f"score: {correct / total:.4f} | {final_losses[0]:.2f} vs {final_losses[1]:.2f}")
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score, "total:", total)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+def evaluate_pisc(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    tsvfile,
+    max_generation_length=20,
+    num_beams=3,
+    length_penalty=-2.0,
+    device=-1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    add_visual=True,
+):
+    from open_flamingo.train.instruction_template import PISC_TEMPLATES
+    dataset_name = "pisc"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    model.train().cuda()
+
+    dataset = wds.WebDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc/000000.tar").decode().to_tuple("image_path.txt", "dataset.txt", "data.pyd")
+    pbar = tqdm(dataset, disable=(rank != 0))
+
+    rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
+    rel_type_to_id = {x: i for i, x in enumerate(rel_id_to_type)}
+    gt = []
+    pred_scores = []
+    for III, sample in enumerate(pbar):
+        if III % world_size != rank:
+            continue
+        image_path, dataset, data = sample
+        image = Image.open(image_path)
+        size = image_processor.transforms[0].size
+        image = image.resize((size, size))
+        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        boxA = data[0]
+        boxB = data[1]
+        gt_relation = data[2]
+        losses = []
+        for i_rel, option_rel in enumerate(rel_id_to_type):
+            text = PISC_TEMPLATES[0].format(relation=option_rel)
+            added_bbox = [
+                torch.tensor([boxA]).cuda(),
+                torch.tensor([boxB]).cuda(),
+            ]
+            caption = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{text}{tokenizer.eos_token}"
+            encodings = tokenizer(
+                caption,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+
+            labels = lang_x.clone()
+            labels[labels == tokenizer.pad_token_id] = -100
+            if add_visual:
+                # endofattr_next_token_index = list((labels == endofattr_token_id).nonzero(as_tuple=True))
+                # endofattr_next_token_index[1] += 1
+                # endofattr_next_token_id = labels[endofattr_next_token_index]
+                # </obj><visual><box></attr>NEXT_WORD
+                # </obj> predict NEXT_WORD
+                # <visual><box></attr> predict nothing
+                labels[labels == visual_token_id] = -100
+                labels[labels == box_token_id] = -100
+                labels[labels == endofattr_token_id] = -100
+                # labels[endofattr_next_token_index] = -100
+            labels[:, 0] = -100
+            answer_token_id = tokenizer(" Answer").input_ids[0]
+            answer_token_loc = (input_ids == answer_token_id).nonzero()
+            for batch_idx, idx in answer_token_loc:
+                labels[batch_idx][:idx+2] = -100
+
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox,
+                    add_box=added_bbox is not None,
+                )
+                loss_total = outputs.loss.reshape(labels.shape[0], -1)
+                loss = loss_total.sum() / (loss_total != 0).sum()
+                losses.append(loss.item())
+        pred_scores.append(np.exp(-np.array(losses)) / np.exp(-np.array(losses)).sum())
+        gt.append(rel_type_to_id[gt_relation])
+    gt = np.array(gt)
+    pred_scores = np.array(pred_scores)
+    pred = pred_scores.argmax(1)
+
+
+    print("total num:", len(gt))
+    recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
+    print("recalls:", recalls)
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([gt.tolist(), pred.tolist()]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        gt = []
+        pred = []
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [gt_part, pred_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            gt.extend(gt_part)
+            pred.extend(pred_part)
+        print("total num:", len(gt))
+        recalls = recall_score(y_true=gt, y_pred=pred, average=None, labels=[0,1,2,3,4,5])
+        print("recalls:", recalls)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}"), "w") as f:
+            f.write(f"{gt}\n")
+            f.write(f"{pred}\n")
+            f.write(f"{recalls}\n")
+    score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/open_flamingo/eval/imagenet_utils.py b/multimodal/open_flamingo/eval/imagenet_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5803c700249335b8ecf947ca81d543d51c9cbbb5
--- /dev/null
+++ b/multimodal/open_flamingo/eval/imagenet_utils.py
@@ -0,0 +1,1007 @@
+# classnames via https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/imagenet_classnames.py#L1
+openai_imagenet_classnames = [
+    "tench",
+    "goldfish",
+    "great white shark",
+    "tiger shark",
+    "hammerhead shark",
+    "electric ray",
+    "stingray",
+    "rooster",
+    "hen",
+    "ostrich",
+    "brambling",
+    "goldfinch",
+    "house finch",
+    "junco",
+    "indigo bunting",
+    "American robin",
+    "bulbul",
+    "jay",
+    "magpie",
+    "chickadee",
+    "American dipper",
+    "kite (bird of prey)",
+    "bald eagle",
+    "vulture",
+    "great grey owl",
+    "fire salamander",
+    "smooth newt",
+    "newt",
+    "spotted salamander",
+    "axolotl",
+    "American bullfrog",
+    "tree frog",
+    "tailed frog",
+    "loggerhead sea turtle",
+    "leatherback sea turtle",
+    "mud turtle",
+    "terrapin",
+    "box turtle",
+    "banded gecko",
+    "green iguana",
+    "Carolina anole",
+    "desert grassland whiptail lizard",
+    "agama",
+    "frilled-necked lizard",
+    "alligator lizard",
+    "Gila monster",
+    "European green lizard",
+    "chameleon",
+    "Komodo dragon",
+    "Nile crocodile",
+    "American alligator",
+    "triceratops",
+    "worm snake",
+    "ring-necked snake",
+    "eastern hog-nosed snake",
+    "smooth green snake",
+    "kingsnake",
+    "garter snake",
+    "water snake",
+    "vine snake",
+    "night snake",
+    "boa constrictor",
+    "African rock python",
+    "Indian cobra",
+    "green mamba",
+    "sea snake",
+    "Saharan horned viper",
+    "eastern diamondback rattlesnake",
+    "sidewinder rattlesnake",
+    "trilobite",
+    "harvestman",
+    "scorpion",
+    "yellow garden spider",
+    "barn spider",
+    "European garden spider",
+    "southern black widow",
+    "tarantula",
+    "wolf spider",
+    "tick",
+    "centipede",
+    "black grouse",
+    "ptarmigan",
+    "ruffed grouse",
+    "prairie grouse",
+    "peafowl",
+    "quail",
+    "partridge",
+    "african grey parrot",
+    "macaw",
+    "sulphur-crested cockatoo",
+    "lorikeet",
+    "coucal",
+    "bee eater",
+    "hornbill",
+    "hummingbird",
+    "jacamar",
+    "toucan",
+    "duck",
+    "red-breasted merganser",
+    "goose",
+    "black swan",
+    "tusker",
+    "echidna",
+    "platypus",
+    "wallaby",
+    "koala",
+    "wombat",
+    "jellyfish",
+    "sea anemone",
+    "brain coral",
+    "flatworm",
+    "nematode",
+    "conch",
+    "snail",
+    "slug",
+    "sea slug",
+    "chiton",
+    "chambered nautilus",
+    "Dungeness crab",
+    "rock crab",
+    "fiddler crab",
+    "red king crab",
+    "American lobster",
+    "spiny lobster",
+    "crayfish",
+    "hermit crab",
+    "isopod",
+    "white stork",
+    "black stork",
+    "spoonbill",
+    "flamingo",
+    "little blue heron",
+    "great egret",
+    "bittern bird",
+    "crane bird",
+    "limpkin",
+    "common gallinule",
+    "American coot",
+    "bustard",
+    "ruddy turnstone",
+    "dunlin",
+    "common redshank",
+    "dowitcher",
+    "oystercatcher",
+    "pelican",
+    "king penguin",
+    "albatross",
+    "grey whale",
+    "killer whale",
+    "dugong",
+    "sea lion",
+    "Chihuahua",
+    "Japanese Chin",
+    "Maltese",
+    "Pekingese",
+    "Shih Tzu",
+    "King Charles Spaniel",
+    "Papillon",
+    "toy terrier",
+    "Rhodesian Ridgeback",
+    "Afghan Hound",
+    "Basset Hound",
+    "Beagle",
+    "Bloodhound",
+    "Bluetick Coonhound",
+    "Black and Tan Coonhound",
+    "Treeing Walker Coonhound",
+    "English foxhound",
+    "Redbone Coonhound",
+    "borzoi",
+    "Irish Wolfhound",
+    "Italian Greyhound",
+    "Whippet",
+    "Ibizan Hound",
+    "Norwegian Elkhound",
+    "Otterhound",
+    "Saluki",
+    "Scottish Deerhound",
+    "Weimaraner",
+    "Staffordshire Bull Terrier",
+    "American Staffordshire Terrier",
+    "Bedlington Terrier",
+    "Border Terrier",
+    "Kerry Blue Terrier",
+    "Irish Terrier",
+    "Norfolk Terrier",
+    "Norwich Terrier",
+    "Yorkshire Terrier",
+    "Wire Fox Terrier",
+    "Lakeland Terrier",
+    "Sealyham Terrier",
+    "Airedale Terrier",
+    "Cairn Terrier",
+    "Australian Terrier",
+    "Dandie Dinmont Terrier",
+    "Boston Terrier",
+    "Miniature Schnauzer",
+    "Giant Schnauzer",
+    "Standard Schnauzer",
+    "Scottish Terrier",
+    "Tibetan Terrier",
+    "Australian Silky Terrier",
+    "Soft-coated Wheaten Terrier",
+    "West Highland White Terrier",
+    "Lhasa Apso",
+    "Flat-Coated Retriever",
+    "Curly-coated Retriever",
+    "Golden Retriever",
+    "Labrador Retriever",
+    "Chesapeake Bay Retriever",
+    "German Shorthaired Pointer",
+    "Vizsla",
+    "English Setter",
+    "Irish Setter",
+    "Gordon Setter",
+    "Brittany dog",
+    "Clumber Spaniel",
+    "English Springer Spaniel",
+    "Welsh Springer Spaniel",
+    "Cocker Spaniel",
+    "Sussex Spaniel",
+    "Irish Water Spaniel",
+    "Kuvasz",
+    "Schipperke",
+    "Groenendael dog",
+    "Malinois",
+    "Briard",
+    "Australian Kelpie",
+    "Komondor",
+    "Old English Sheepdog",
+    "Shetland Sheepdog",
+    "collie",
+    "Border Collie",
+    "Bouvier des Flandres dog",
+    "Rottweiler",
+    "German Shepherd Dog",
+    "Dobermann",
+    "Miniature Pinscher",
+    "Greater Swiss Mountain Dog",
+    "Bernese Mountain Dog",
+    "Appenzeller Sennenhund",
+    "Entlebucher Sennenhund",
+    "Boxer",
+    "Bullmastiff",
+    "Tibetan Mastiff",
+    "French Bulldog",
+    "Great Dane",
+    "St. Bernard",
+    "husky",
+    "Alaskan Malamute",
+    "Siberian Husky",
+    "Dalmatian",
+    "Affenpinscher",
+    "Basenji",
+    "pug",
+    "Leonberger",
+    "Newfoundland dog",
+    "Great Pyrenees dog",
+    "Samoyed",
+    "Pomeranian",
+    "Chow Chow",
+    "Keeshond",
+    "brussels griffon",
+    "Pembroke Welsh Corgi",
+    "Cardigan Welsh Corgi",
+    "Toy Poodle",
+    "Miniature Poodle",
+    "Standard Poodle",
+    "Mexican hairless dog (xoloitzcuintli)",
+    "grey wolf",
+    "Alaskan tundra wolf",
+    "red wolf or maned wolf",
+    "coyote",
+    "dingo",
+    "dhole",
+    "African wild dog",
+    "hyena",
+    "red fox",
+    "kit fox",
+    "Arctic fox",
+    "grey fox",
+    "tabby cat",
+    "tiger cat",
+    "Persian cat",
+    "Siamese cat",
+    "Egyptian Mau",
+    "cougar",
+    "lynx",
+    "leopard",
+    "snow leopard",
+    "jaguar",
+    "lion",
+    "tiger",
+    "cheetah",
+    "brown bear",
+    "American black bear",
+    "polar bear",
+    "sloth bear",
+    "mongoose",
+    "meerkat",
+    "tiger beetle",
+    "ladybug",
+    "ground beetle",
+    "longhorn beetle",
+    "leaf beetle",
+    "dung beetle",
+    "rhinoceros beetle",
+    "weevil",
+    "fly",
+    "bee",
+    "ant",
+    "grasshopper",
+    "cricket insect",
+    "stick insect",
+    "cockroach",
+    "praying mantis",
+    "cicada",
+    "leafhopper",
+    "lacewing",
+    "dragonfly",
+    "damselfly",
+    "red admiral butterfly",
+    "ringlet butterfly",
+    "monarch butterfly",
+    "small white butterfly",
+    "sulphur butterfly",
+    "gossamer-winged butterfly",
+    "starfish",
+    "sea urchin",
+    "sea cucumber",
+    "cottontail rabbit",
+    "hare",
+    "Angora rabbit",
+    "hamster",
+    "porcupine",
+    "fox squirrel",
+    "marmot",
+    "beaver",
+    "guinea pig",
+    "common sorrel horse",
+    "zebra",
+    "pig",
+    "wild boar",
+    "warthog",
+    "hippopotamus",
+    "ox",
+    "water buffalo",
+    "bison",
+    "ram (adult male sheep)",
+    "bighorn sheep",
+    "Alpine ibex",
+    "hartebeest",
+    "impala (antelope)",
+    "gazelle",
+    "arabian camel",
+    "llama",
+    "weasel",
+    "mink",
+    "European polecat",
+    "black-footed ferret",
+    "otter",
+    "skunk",
+    "badger",
+    "armadillo",
+    "three-toed sloth",
+    "orangutan",
+    "gorilla",
+    "chimpanzee",
+    "gibbon",
+    "siamang",
+    "guenon",
+    "patas monkey",
+    "baboon",
+    "macaque",
+    "langur",
+    "black-and-white colobus",
+    "proboscis monkey",
+    "marmoset",
+    "white-headed capuchin",
+    "howler monkey",
+    "titi monkey",
+    "Geoffroy's spider monkey",
+    "common squirrel monkey",
+    "ring-tailed lemur",
+    "indri",
+    "Asian elephant",
+    "African bush elephant",
+    "red panda",
+    "giant panda",
+    "snoek fish",
+    "eel",
+    "silver salmon",
+    "rock beauty fish",
+    "clownfish",
+    "sturgeon",
+    "gar fish",
+    "lionfish",
+    "pufferfish",
+    "abacus",
+    "abaya",
+    "academic gown",
+    "accordion",
+    "acoustic guitar",
+    "aircraft carrier",
+    "airliner",
+    "airship",
+    "altar",
+    "ambulance",
+    "amphibious vehicle",
+    "analog clock",
+    "apiary",
+    "apron",
+    "trash can",
+    "assault rifle",
+    "backpack",
+    "bakery",
+    "balance beam",
+    "balloon",
+    "ballpoint pen",
+    "Band-Aid",
+    "banjo",
+    "baluster / handrail",
+    "barbell",
+    "barber chair",
+    "barbershop",
+    "barn",
+    "barometer",
+    "barrel",
+    "wheelbarrow",
+    "baseball",
+    "basketball",
+    "bassinet",
+    "bassoon",
+    "swimming cap",
+    "bath towel",
+    "bathtub",
+    "station wagon",
+    "lighthouse",
+    "beaker",
+    "military hat (bearskin or shako)",
+    "beer bottle",
+    "beer glass",
+    "bell tower",
+    "baby bib",
+    "tandem bicycle",
+    "bikini",
+    "ring binder",
+    "binoculars",
+    "birdhouse",
+    "boathouse",
+    "bobsleigh",
+    "bolo tie",
+    "poke bonnet",
+    "bookcase",
+    "bookstore",
+    "bottle cap",
+    "hunting bow",
+    "bow tie",
+    "brass memorial plaque",
+    "bra",
+    "breakwater",
+    "breastplate",
+    "broom",
+    "bucket",
+    "buckle",
+    "bulletproof vest",
+    "high-speed train",
+    "butcher shop",
+    "taxicab",
+    "cauldron",
+    "candle",
+    "cannon",
+    "canoe",
+    "can opener",
+    "cardigan",
+    "car mirror",
+    "carousel",
+    "tool kit",
+    "cardboard box / carton",
+    "car wheel",
+    "automated teller machine",
+    "cassette",
+    "cassette player",
+    "castle",
+    "catamaran",
+    "CD player",
+    "cello",
+    "mobile phone",
+    "chain",
+    "chain-link fence",
+    "chain mail",
+    "chainsaw",
+    "storage chest",
+    "chiffonier",
+    "bell or wind chime",
+    "china cabinet",
+    "Christmas stocking",
+    "church",
+    "movie theater",
+    "cleaver",
+    "cliff dwelling",
+    "cloak",
+    "clogs",
+    "cocktail shaker",
+    "coffee mug",
+    "coffeemaker",
+    "spiral or coil",
+    "combination lock",
+    "computer keyboard",
+    "candy store",
+    "container ship",
+    "convertible",
+    "corkscrew",
+    "cornet",
+    "cowboy boot",
+    "cowboy hat",
+    "cradle",
+    "construction crane",
+    "crash helmet",
+    "crate",
+    "infant bed",
+    "Crock Pot",
+    "croquet ball",
+    "crutch",
+    "cuirass",
+    "dam",
+    "desk",
+    "desktop computer",
+    "rotary dial telephone",
+    "diaper",
+    "digital clock",
+    "digital watch",
+    "dining table",
+    "dishcloth",
+    "dishwasher",
+    "disc brake",
+    "dock",
+    "dog sled",
+    "dome",
+    "doormat",
+    "drilling rig",
+    "drum",
+    "drumstick",
+    "dumbbell",
+    "Dutch oven",
+    "electric fan",
+    "electric guitar",
+    "electric locomotive",
+    "entertainment center",
+    "envelope",
+    "espresso machine",
+    "face powder",
+    "feather boa",
+    "filing cabinet",
+    "fireboat",
+    "fire truck",
+    "fire screen",
+    "flagpole",
+    "flute",
+    "folding chair",
+    "football helmet",
+    "forklift",
+    "fountain",
+    "fountain pen",
+    "four-poster bed",
+    "freight car",
+    "French horn",
+    "frying pan",
+    "fur coat",
+    "garbage truck",
+    "gas mask or respirator",
+    "gas pump",
+    "goblet",
+    "go-kart",
+    "golf ball",
+    "golf cart",
+    "gondola",
+    "gong",
+    "gown",
+    "grand piano",
+    "greenhouse",
+    "radiator grille",
+    "grocery store",
+    "guillotine",
+    "hair clip",
+    "hair spray",
+    "half-track",
+    "hammer",
+    "hamper",
+    "hair dryer",
+    "hand-held computer",
+    "handkerchief",
+    "hard disk drive",
+    "harmonica",
+    "harp",
+    "combine harvester",
+    "hatchet",
+    "holster",
+    "home theater",
+    "honeycomb",
+    "hook",
+    "hoop skirt",
+    "gymnastic horizontal bar",
+    "horse-drawn vehicle",
+    "hourglass",
+    "iPod",
+    "clothes iron",
+    "carved pumpkin",
+    "jeans",
+    "jeep",
+    "T-shirt",
+    "jigsaw puzzle",
+    "rickshaw",
+    "joystick",
+    "kimono",
+    "knee pad",
+    "knot",
+    "lab coat",
+    "ladle",
+    "lampshade",
+    "laptop computer",
+    "lawn mower",
+    "lens cap",
+    "letter opener",
+    "library",
+    "lifeboat",
+    "lighter",
+    "limousine",
+    "ocean liner",
+    "lipstick",
+    "slip-on shoe",
+    "lotion",
+    "music speaker",
+    "loupe magnifying glass",
+    "sawmill",
+    "magnetic compass",
+    "messenger bag",
+    "mailbox",
+    "tights",
+    "one-piece bathing suit",
+    "manhole cover",
+    "maraca",
+    "marimba",
+    "mask",
+    "matchstick",
+    "maypole",
+    "maze",
+    "measuring cup",
+    "medicine cabinet",
+    "megalith",
+    "microphone",
+    "microwave oven",
+    "military uniform",
+    "milk can",
+    "minibus",
+    "miniskirt",
+    "minivan",
+    "missile",
+    "mitten",
+    "mixing bowl",
+    "mobile home",
+    "ford model t",
+    "modem",
+    "monastery",
+    "monitor",
+    "moped",
+    "mortar and pestle",
+    "graduation cap",
+    "mosque",
+    "mosquito net",
+    "vespa",
+    "mountain bike",
+    "tent",
+    "computer mouse",
+    "mousetrap",
+    "moving van",
+    "muzzle",
+    "metal nail",
+    "neck brace",
+    "necklace",
+    "baby pacifier",
+    "notebook computer",
+    "obelisk",
+    "oboe",
+    "ocarina",
+    "odometer",
+    "oil filter",
+    "pipe organ",
+    "oscilloscope",
+    "overskirt",
+    "bullock cart",
+    "oxygen mask",
+    "product packet / packaging",
+    "paddle",
+    "paddle wheel",
+    "padlock",
+    "paintbrush",
+    "pajamas",
+    "palace",
+    "pan flute",
+    "paper towel",
+    "parachute",
+    "parallel bars",
+    "park bench",
+    "parking meter",
+    "railroad car",
+    "patio",
+    "payphone",
+    "pedestal",
+    "pencil case",
+    "pencil sharpener",
+    "perfume",
+    "Petri dish",
+    "photocopier",
+    "plectrum",
+    "Pickelhaube",
+    "picket fence",
+    "pickup truck",
+    "pier",
+    "piggy bank",
+    "pill bottle",
+    "pillow",
+    "ping-pong ball",
+    "pinwheel",
+    "pirate ship",
+    "drink pitcher",
+    "block plane",
+    "planetarium",
+    "plastic bag",
+    "plate rack",
+    "farm plow",
+    "plunger",
+    "Polaroid camera",
+    "pole",
+    "police van",
+    "poncho",
+    "pool table",
+    "soda bottle",
+    "plant pot",
+    "potter's wheel",
+    "power drill",
+    "prayer rug",
+    "printer",
+    "prison",
+    "missile",
+    "projector",
+    "hockey puck",
+    "punching bag",
+    "purse",
+    "quill",
+    "quilt",
+    "race car",
+    "racket",
+    "radiator",
+    "radio",
+    "radio telescope",
+    "rain barrel",
+    "recreational vehicle",
+    "fishing casting reel",
+    "reflex camera",
+    "refrigerator",
+    "remote control",
+    "restaurant",
+    "revolver",
+    "rifle",
+    "rocking chair",
+    "rotisserie",
+    "eraser",
+    "rugby ball",
+    "ruler measuring stick",
+    "sneaker",
+    "safe",
+    "safety pin",
+    "salt shaker",
+    "sandal",
+    "sarong",
+    "saxophone",
+    "scabbard",
+    "weighing scale",
+    "school bus",
+    "schooner",
+    "scoreboard",
+    "CRT monitor",
+    "screw",
+    "screwdriver",
+    "seat belt",
+    "sewing machine",
+    "shield",
+    "shoe store",
+    "shoji screen / room divider",
+    "shopping basket",
+    "shopping cart",
+    "shovel",
+    "shower cap",
+    "shower curtain",
+    "ski",
+    "balaclava ski mask",
+    "sleeping bag",
+    "slide rule",
+    "sliding door",
+    "slot machine",
+    "snorkel",
+    "snowmobile",
+    "snowplow",
+    "soap dispenser",
+    "soccer ball",
+    "sock",
+    "solar thermal collector",
+    "sombrero",
+    "soup bowl",
+    "keyboard space bar",
+    "space heater",
+    "space shuttle",
+    "spatula",
+    "motorboat",
+    "spider web",
+    "spindle",
+    "sports car",
+    "spotlight",
+    "stage",
+    "steam locomotive",
+    "through arch bridge",
+    "steel drum",
+    "stethoscope",
+    "scarf",
+    "stone wall",
+    "stopwatch",
+    "stove",
+    "strainer",
+    "tram",
+    "stretcher",
+    "couch",
+    "stupa",
+    "submarine",
+    "suit",
+    "sundial",
+    "sunglasses",
+    "sunglasses",
+    "sunscreen",
+    "suspension bridge",
+    "mop",
+    "sweatshirt",
+    "swim trunks / shorts",
+    "swing",
+    "electrical switch",
+    "syringe",
+    "table lamp",
+    "tank",
+    "tape player",
+    "teapot",
+    "teddy bear",
+    "television",
+    "tennis ball",
+    "thatched roof",
+    "front curtain",
+    "thimble",
+    "threshing machine",
+    "throne",
+    "tile roof",
+    "toaster",
+    "tobacco shop",
+    "toilet seat",
+    "torch",
+    "totem pole",
+    "tow truck",
+    "toy store",
+    "tractor",
+    "semi-trailer truck",
+    "tray",
+    "trench coat",
+    "tricycle",
+    "trimaran",
+    "tripod",
+    "triumphal arch",
+    "trolleybus",
+    "trombone",
+    "hot tub",
+    "turnstile",
+    "typewriter keyboard",
+    "umbrella",
+    "unicycle",
+    "upright piano",
+    "vacuum cleaner",
+    "vase",
+    "vaulted or arched ceiling",
+    "velvet fabric",
+    "vending machine",
+    "vestment",
+    "viaduct",
+    "violin",
+    "volleyball",
+    "waffle iron",
+    "wall clock",
+    "wallet",
+    "wardrobe",
+    "military aircraft",
+    "sink",
+    "washing machine",
+    "water bottle",
+    "water jug",
+    "water tower",
+    "whiskey jug",
+    "whistle",
+    "hair wig",
+    "window screen",
+    "window shade",
+    "Windsor tie",
+    "wine bottle",
+    "airplane wing",
+    "wok",
+    "wooden spoon",
+    "wool",
+    "split-rail fence",
+    "shipwreck",
+    "sailboat",
+    "yurt",
+    "website",
+    "comic book",
+    "crossword",
+    "traffic or street sign",
+    "traffic light",
+    "dust jacket",
+    "menu",
+    "plate",
+    "guacamole",
+    "consomme",
+    "hot pot",
+    "trifle",
+    "ice cream",
+    "popsicle",
+    "baguette",
+    "bagel",
+    "pretzel",
+    "cheeseburger",
+    "hot dog",
+    "mashed potatoes",
+    "cabbage",
+    "broccoli",
+    "cauliflower",
+    "zucchini",
+    "spaghetti squash",
+    "acorn squash",
+    "butternut squash",
+    "cucumber",
+    "artichoke",
+    "bell pepper",
+    "cardoon",
+    "mushroom",
+    "Granny Smith apple",
+    "strawberry",
+    "orange",
+    "lemon",
+    "fig",
+    "pineapple",
+    "banana",
+    "jackfruit",
+    "cherimoya (custard apple)",
+    "pomegranate",
+    "hay",
+    "carbonara",
+    "chocolate syrup",
+    "dough",
+    "meatloaf",
+    "pizza",
+    "pot pie",
+    "burrito",
+    "red wine",
+    "espresso",
+    "tea cup",
+    "eggnog",
+    "mountain",
+    "bubble",
+    "cliff",
+    "coral reef",
+    "geyser",
+    "lakeshore",
+    "promontory",
+    "sandbar",
+    "beach",
+    "valley",
+    "volcano",
+    "baseball player",
+    "bridegroom",
+    "scuba diver",
+    "rapeseed",
+    "daisy",
+    "yellow lady's slipper",
+    "corn",
+    "acorn",
+    "rose hip",
+    "horse chestnut seed",
+    "coral fungus",
+    "agaric",
+    "gyromitra",
+    "stinkhorn mushroom",
+    "earth star fungus",
+    "hen of the woods mushroom",
+    "bolete",
+    "corn cob",
+    "toilet paper",
+]
+# Maps numeric class ids to labels
+IMAGENET_1K_CLASS_ID_TO_LABEL = dict(
+    zip(range(len(openai_imagenet_classnames)), openai_imagenet_classnames)
+)
diff --git a/multimodal/open_flamingo/eval/ok_vqa_utils.py b/multimodal/open_flamingo/eval/ok_vqa_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2db61942fd0263213fe92b1025e906e59220380b
--- /dev/null
+++ b/multimodal/open_flamingo/eval/ok_vqa_utils.py
@@ -0,0 +1,213 @@
+# Those are manual mapping that are not caught by our stemming rules or would
+# would be done incorrectly by our automatic stemming rule. In details,
+# the keys of the _MANUAL_MATCHES dict contains the original word and the value
+# contains the transformation of the word expected by the OKVQA stemming rule.
+# These manual rules were found by checking the `raw_answers` and the `answers`
+# fields of the released OKVQA dataset and checking all things that were not
+# properly mapped by our automatic rules. In particular some of the mapping
+# are sometimes constant, e.g. christmas -> christmas which was incorrectly
+# singularized by our inflection.singularize.
+import re
+import nltk
+from nltk.corpus.reader import VERB
+import inflection
+
+_MANUAL_MATCHES = {
+    "police": "police",
+    "las": "las",
+    "vegas": "vegas",
+    "yes": "yes",
+    "jeans": "jean",
+    "hell's": "hell",
+    "domino's": "domino",
+    "morning": "morn",
+    "clothes": "cloth",
+    "are": "are",
+    "riding": "ride",
+    "leaves": "leaf",
+    "dangerous": "danger",
+    "clothing": "cloth",
+    "texting": "text",
+    "kiting": "kite",
+    "firefighters": "firefight",
+    "ties": "tie",
+    "married": "married",
+    "teething": "teeth",
+    "gloves": "glove",
+    "tennis": "tennis",
+    "dining": "dine",
+    "directions": "direct",
+    "waves": "wave",
+    "christmas": "christmas",
+    "drives": "drive",
+    "pudding": "pud",
+    "coding": "code",
+    "plating": "plate",
+    "quantas": "quanta",
+    "hornes": "horn",
+    "graves": "grave",
+    "mating": "mate",
+    "paned": "pane",
+    "alertness": "alert",
+    "sunbathing": "sunbath",
+    "tenning": "ten",
+    "wetness": "wet",
+    "urinating": "urine",
+    "sickness": "sick",
+    "braves": "brave",
+    "firefighting": "firefight",
+    "lenses": "lens",
+    "reflections": "reflect",
+    "backpackers": "backpack",
+    "eatting": "eat",
+    "designers": "design",
+    "curiousity": "curious",
+    "playfulness": "play",
+    "blindness": "blind",
+    "hawke": "hawk",
+    "tomatoe": "tomato",
+    "rodeoing": "rodeo",
+    "brightness": "bright",
+    "circuses": "circus",
+    "skateboarders": "skateboard",
+    "staring": "stare",
+    "electronics": "electron",
+    "electicity": "elect",
+    "mountainous": "mountain",
+    "socializing": "social",
+    "hamburgers": "hamburg",
+    "caves": "cave",
+    "transitions": "transit",
+    "wading": "wade",
+    "creame": "cream",
+    "toileting": "toilet",
+    "sautee": "saute",
+    "buildings": "build",
+    "belongings": "belong",
+    "stockings": "stock",
+    "walle": "wall",
+    "cumulis": "cumuli",
+    "travelers": "travel",
+    "conducter": "conduct",
+    "browsing": "brows",
+    "pooping": "poop",
+    "haircutting": "haircut",
+    "toppings": "top",
+    "hearding": "heard",
+    "sunblocker": "sunblock",
+    "bases": "base",
+    "markings": "mark",
+    "mopeds": "mope",
+    "kindergartener": "kindergarten",
+    "pies": "pie",
+    "scrapbooking": "scrapbook",
+    "couponing": "coupon",
+    "meetings": "meet",
+    "elevators": "elev",
+    "lowes": "low",
+    "men's": "men",
+    "childrens": "children",
+    "shelves": "shelve",
+    "paintings": "paint",
+    "raines": "rain",
+    "paring": "pare",
+    "expressions": "express",
+    "routes": "rout",
+    "pease": "peas",
+    "vastness": "vast",
+    "awning": "awn",
+    "boy's": "boy",
+    "drunkenness": "drunken",
+    "teasing": "teas",
+    "conferences": "confer",
+    "ripeness": "ripe",
+    "suspenders": "suspend",
+    "earnings": "earn",
+    "reporters": "report",
+    "kid's": "kid",
+    "containers": "contain",
+    "corgie": "corgi",
+    "porche": "porch",
+    "microwaves": "microwave",
+    "batter's": "batter",
+    "sadness": "sad",
+    "apartments": "apart",
+    "oxygenize": "oxygen",
+    "striping": "stripe",
+    "purring": "pure",
+    "professionals": "profession",
+    "piping": "pipe",
+    "farmer's": "farmer",
+    "potatoe": "potato",
+    "emirates": "emir",
+    "womens": "women",
+    "veteran's": "veteran",
+    "wilderness": "wilder",
+    "propellers": "propel",
+    "alpes": "alp",
+    "charioteering": "chariot",
+    "swining": "swine",
+    "illness": "ill",
+    "crepte": "crept",
+    "adhesives": "adhesive",
+    "regent's": "regent",
+    "decorations": "decor",
+    "rabbies": "rabbi",
+    "overseas": "oversea",
+    "travellers": "travel",
+    "casings": "case",
+    "smugness": "smug",
+    "doves": "dove",
+    "nationals": "nation",
+    "mustange": "mustang",
+    "ringe": "ring",
+    "gondoliere": "gondolier",
+    "vacationing": "vacate",
+    "reminders": "remind",
+    "baldness": "bald",
+    "settings": "set",
+    "glaced": "glace",
+    "coniferous": "conifer",
+    "revelations": "revel",
+    "personals": "person",
+    "daughter's": "daughter",
+    "badness": "bad",
+    "projections": "project",
+    "polarizing": "polar",
+    "vandalizers": "vandal",
+    "minerals": "miner",
+    "protesters": "protest",
+    "controllers": "control",
+    "weddings": "wed",
+    "sometimes": "sometime",
+    "earing": "ear",
+}
+
+
+class OKVQAStemmer:
+    """Stemmer to match OKVQA v1.1 procedure."""
+
+    def __init__(self):
+        self._wordnet_lemmatizer = nltk.stem.WordNetLemmatizer()
+
+    def stem(self, input_string):
+        """Apply stemming."""
+        word_and_pos = nltk.pos_tag(nltk.tokenize.word_tokenize(input_string))
+        stemmed_words = []
+        for w, p in word_and_pos:
+            if w in _MANUAL_MATCHES:
+                w = _MANUAL_MATCHES[w]
+            elif w.endswith("ing"):
+                w = self._wordnet_lemmatizer.lemmatize(w, VERB)
+            elif p.startswith("NNS") or p.startswith("NNPS"):
+                w = inflection.singularize(w)
+            stemmed_words.append(w)
+        return " ".join(stemmed_words)
+
+
+stemmer = OKVQAStemmer()
+
+
+def postprocess_ok_vqa_generation(prediction) -> str:
+    prediction_stem = stemmer.stem(prediction)
+    return prediction_stem
diff --git a/multimodal/open_flamingo/eval/task/__init__.py b/multimodal/open_flamingo/eval/task/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/multimodal/open_flamingo/eval/task/caption.py b/multimodal/open_flamingo/eval/task/caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..f70faa1b3c11248f3576e6c7c52e4f9ea2436691
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/caption.py
@@ -0,0 +1,284 @@
+from lavis.datasets.builders import load_dataset
+import torch
+import more_itertools
+from tqdm import tqdm
+from coco_metric import compute_cider, postprocess_captioning_generation
+import json
+import time
+import os
+from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor
+
+
+class VisualLogitsProcessor(LogitsProcessor):
+    def __init__(self, tokenizer):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.object_token_id = self.tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
+        self.prebox_token_id = self.tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+        self.box_token_id = self.tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+        self.previsual_token_id = self.tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+        self.visual_token_id = self.tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+        self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
+        self.endofobject_token_id = self.tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+        self.topk = 2
+
+    def __call__(self, input_ids, scores):
+        # print("decoding===>", self.tokenizer.decode(scores.sort(descending=True).indices.tolist()[0][:self.topk]))
+        # import pdb; pdb.set_trace()
+        if self.object_token_id in scores.sort(descending=True).indices.tolist()[0][1:self.topk] and self.eos_token_id not in scores.sort(descending=True).indices.tolist()[0][:self.topk] and (input_ids == self.object_token_id).sum() * 2 == (input_ids == self.endofobject_token_id).sum():
+            scores[0, self.object_token_id] = 1000
+        if input_ids[0, -1] == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
+            if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
+                # print("generate a previsual token next")
+                scores[0, self.previsual_token_id] = 1000
+        elif input_ids[0, -1] == self.previsual_token_id or input_ids[0, -1] == self.visual_token_id:
+            # print("stop to run bbox generation for " + "previsual" if input_ids[0, -1] == self.previsual_token_id else "visual")
+            scores[0, self.eos_token_id] = 1000
+        elif input_ids[0, -1] == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
+            # print("generate a visual token next")
+            scores[0, self.visual_token_id] = 1000
+        return scores
+
+
+def prepare_batch_images(batch, image_processor):
+    batch_images = None
+    for b in batch:
+        b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        if batch_images is None:
+            batch_images = b_image
+        else:
+            batch_images = torch.cat([batch_images, b_image], dim=0)
+    return batch_images
+
+
+def evaluate_coco_flickr(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size,
+    is_flickr=False,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    debug=False,
+):
+    """Evaluate a model on COCO dataset.
+    Returns:
+        float: CIDEr score
+
+    """
+    visual_logits_processor = VisualLogitsProcessor(tokenizer)
+    coco_dataset = load_dataset("coco_caption")
+    eval_dataset = coco_dataset["test"]
+    model.eval().cuda()
+    predictions = dict()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token = "<|#box#|>"
+    prebox_token = "<|#prebox#|>"
+    endofobject_token = "<|#endofobject#|>"
+    object_token = "<|#object#|>"
+    cnt = 0
+    if world_size > 1:
+        torch.distributed.barrier()
+    desc = "Running inference Flickr30" if is_flickr else "Running inference COCO"
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc=desc, disable=(rank != 0)), batch_size
+    )):
+        if ii % world_size != rank:
+            continue
+        cnt += len(batch)
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        prompt = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"
+        added_bbox_list = []
+        batch_text = [prompt for _ in batch]
+        encodings = tokenizer(
+            batch_text,
+            padding="longest",
+            truncation=True,
+            return_tensors="pt",
+            max_length=2000,
+        )
+        ori_prompt_length = len(encodings["input_ids"][0])
+        have_prebox = False
+        while True:
+            batch_text = [prompt for _ in batch]
+            encodings = tokenizer(
+                batch_text,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=2000,
+            )
+            input_ids = encodings["input_ids"].cuda()
+            attention_mask = encodings["attention_mask"].cuda()
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            if debug:
+                print("input--->",tokenizer.decode(input_ids[0]))
+            p1 = MinNewTokensLengthLogitsProcessor(
+                prompt_length_to_skip=input_ids.shape[-1],
+                min_new_tokens=5,
+                eos_token_id=bos_token_id,
+            )
+            with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+                outputs = model.generate(
+                    batch_images,
+                    input_ids,
+                    attention_mask=attention_mask,
+                    max_new_tokens=20,
+                    # min_new_tokens=8,
+                    num_beams=1,
+                    # length_penalty=0,
+                    image_start_index_list=image_start_index_list,
+                    image_nums=image_nums,
+                    added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                    logits_processor_list=[p1, visual_logits_processor],
+                )
+            if debug:
+                print("outputs--->",tokenizer.decode(outputs[0]))
+            if outputs[0, -2] in [previsual_token_id, visual_token_id] and outputs[0, -1] == bos_token_id:
+                prompt = tokenizer.decode(outputs.clone()[0])
+                is_visual = (outputs[0, -2] == visual_token_id)
+                batch_text = tokenizer.batch_decode(outputs[:, :-1])
+                encodings = tokenizer(
+                    batch_text,
+                    padding="longest",
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=2000,
+                )
+                input_ids = encodings["input_ids"].cuda()
+                attention_mask = encodings["attention_mask"].cuda()
+                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+                image_start_index_list = [[x] for x in image_start_index_list]
+                image_nums = [1] * len(input_ids)
+                if debug:
+                    print("get the visual bbox--->",tokenizer.decode(input_ids[0]))
+                with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                    outputs = model(
+                        vision_x=batch_images,
+                        lang_x=input_ids,
+                        attention_mask=attention_mask,
+                        image_nums=image_nums,
+                        image_start_index_list=image_start_index_list,
+                        added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None,
+                        add_box=added_bbox_list is not None and len(added_bbox_list) != 0,
+                    )
+                boxes = outputs["boxes"]
+                scores = outputs["scores"]
+                # if not model.valid:
+                #     import pdb; pdb.set_trace()
+                if boxes is not None:
+                    if is_visual:
+                        if have_prebox:
+                            added_bbox_list.pop()
+                            prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
+                            have_prebox = False
+                            if debug:
+                                print("find previsual and remove it--->", prompt)
+                        first_box = boxes[scores.argmax()]
+                        added_bbox_list += [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
+                        prompt = prompt[:-len(tokenizer.eos_token)]
+                        prompt += box_token + endofobject_token
+                        if debug:
+                            print("after inserting visual---->", prompt)
+                    else:
+                        # import numpy as np
+                        # import cv2
+                        # open_cv_image = np.array(batch[0]["image"])
+                        # open_cv_image = open_cv_image[:, :, ::-1].copy()
+                        # for pre_box in boxes:
+                        #     open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
+                        # cv2.imwrite("Atest.png", open_cv_image)
+                        pre_box = boxes[scores.argmax()]
+                        added_bbox_list += [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
+                        prompt = prompt[:-len(tokenizer.eos_token)]
+                        prompt += prebox_token + object_token
+                        have_prebox = True
+                        if debug:
+                            print("after inserting previsual---->", prompt)
+                else:
+                    import pdb;pdb.set_trace()
+                    prompt = tokenizer.decode(outputs[0, :-2].clone()[0])
+            else:
+                break
+        outputs = outputs[:, ori_prompt_length:]
+        new_predictions = [
+            postprocess_captioning_generation(out).replace('"', "")
+            for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ]
+        # import pdb; pdb.set_trace()
+        if rank == 0:
+            tqdm.write(new_predictions[0])
+        for i, sample in enumerate(batch):
+            predictions[int(sample["image_id"])] = {
+                "caption": new_predictions[i],
+            }
+    results_path = (
+        f"flickrresults_{lang_encoder_name}_{rank}_{id}.json"
+        if is_flickr
+        else f"cocoresults_{lang_encoder_name}_{rank}_{id}.json"
+    )
+    with open(results_path, "w") as f:
+        f.write(
+            json.dumps(
+                [
+                    {"image_id": k, "caption": predictions[k]["caption"]}
+                    for k in predictions
+                ],
+                indent=2,
+            )
+        )
+    print("save to", results_path)
+    del predictions
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            part_results_path = (
+                f"flickrresults_{lang_encoder_name}_{rank_i}_{id}.json"
+                if is_flickr
+                else f"cocoresults_{lang_encoder_name}_{rank_i}_{id}.json"
+            )
+            print("load", part_results_path)
+            predictions.extend(json.load(open(part_results_path)))
+            os.remove(part_results_path)
+        print("num:", len(predictions))
+        results_path = (
+            f"flickrresults_{lang_encoder_name}.json"
+            if is_flickr
+            else f"cocoresults_{lang_encoder_name}.json"
+        )
+        json.dump(predictions, open(results_path, "w"), indent=2)
+
+        metrics = compute_cider(
+            result_path=results_path,
+            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json",
+        )
+        metrics["CIDEr"] *= 100
+        os.makedirs("eval_results", exist_ok=True)
+        acc = metrics["CIDEr"]
+        with open(os.path.join("eval_results", f"cococap_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(results_path)
+    else:
+        metrics = {}
+        metrics["CIDEr"] = 0.0
+
+    return metrics["CIDEr"]
diff --git a/multimodal/open_flamingo/eval/task/cola.py b/multimodal/open_flamingo/eval/task/cola.py
new file mode 100644
index 0000000000000000000000000000000000000000..08ad1e2f46a6460dcaa1be12624130b984cf31ab
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/cola.py
@@ -0,0 +1,220 @@
+import json
+import webdataset as wds
+from tqdm import tqdm
+from PIL import Image
+import torch
+import numpy as np
+import os
+import time
+import cv2
+import random
+import math
+from open_flamingo.eval.task.utils import (
+    get_object_from_text,
+    is_correct,
+    _eval_text_image,
+    get_bbox,
+    get_iou,
+)
+DATASET = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/COLA/data/COLA_multiobjects_matching_benchmark.json"
+VG_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/VG_100K"
+
+def get_score(image, text, model, tokenizer, image_processor, vis_embed_size):
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
+    text = text.split("#")
+    obj_A = text[0].strip().split(" ")
+    relation = text[1].strip()
+    obj_B = text[2].strip().split(" ")
+    if "computer mouse" not in text[0].strip():
+        attrAs = obj_A[:-1]
+        nounA = obj_A[-1]
+    else:
+        attrAs = obj_A[:-2]
+        nounA = " ".join(obj_A[-2:])
+    if "computer mouse" not in text[2].strip():
+        attrBs = obj_B[:-1]
+        nounB = obj_B[-1]
+    else:
+        attrBs = obj_B[:-2]
+        nounB = " ".join(obj_B[-2:])
+    # print("="*80)
+    # print(attrAs, nounA)
+    # print(attrBs, nounB)
+    # print(relation)
+    # print("="*80)
+    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+
+
+    prompt1 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {nounA}<|#endofobject#|><|#visual#|>"]
+    boxes, scores = get_bbox(None, batch_images, prompt1, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
+
+
+    # open_cv_image = np.array(image)
+    # open_cv_image = open_cv_image[:, :, ::-1].copy()
+    # for pre_box in boxes:
+    #     open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
+
+    box_ppl = []
+    box_attr_losses = []
+    for box in boxes:
+        losses = []
+        for attrA in attrAs:
+            prompt2 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {attrA} {nounA}"]
+            encodings = tokenizer(
+                prompt2,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=512,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+            labels = lang_x.clone()
+            start_idx = (labels == object_token_id).nonzero()[-1, -1]
+            labels[0, :start_idx+1] = -100
+            added_bbox_list = [torch.tensor(box / 224.0).cuda().unsqueeze(0)]
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox_list,
+                    add_box=added_bbox_list is not None,
+                    relations=None,
+                )
+            loss = outputs.loss
+            loss = (loss.sum() / (loss != 0).sum()).item()
+            losses.append(loss)
+        avg_ppl = np.array(losses).mean()
+        box_ppl.append(avg_ppl)
+        box_attr_losses.append(losses)
+    fit_idx = np.array(box_ppl).argmin()
+    fit_box = boxes[fit_idx]
+    fit_attr = attrAs[np.array(box_attr_losses[fit_idx]).argmin()]
+    first_ppl = min(box_ppl)
+
+    # open_cv_image = cv2.rectangle(open_cv_image, fit_box[:2].astype(int), fit_box[2:].astype(int), (255, 0, 0), 2)
+
+
+    prompt3 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {fit_attr} {nounA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> is {relation}<|#object#|><|#previsual#|>"]
+    boxes, scores = get_bbox([torch.tensor(fit_box / 224).cuda().unsqueeze(0)], batch_images, prompt3, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
+    # for i, pre_box in enumerate(boxes):
+    #     open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 0, 255), i+1)
+    # cv2.imwrite(f"Atest.png", open_cv_image)
+
+    box_ppl = []
+    for box in boxes:
+        losses = []
+        for attrB in attrBs:
+            prompt4 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|>the {fit_attr} {nounA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> is {relation}<|#object#|><|#previsual#|><|#prebox#|><|#object#|> the {attrB} {nounB}"]
+            encodings = tokenizer(
+                prompt4,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=512,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+            labels = lang_x.clone()
+            start_idx = (labels == object_token_id).nonzero()[-1, -1]
+            labels[0, :start_idx+1] = -100
+            added_bbox_list = [torch.tensor(fit_box / 224.0).cuda().unsqueeze(0), torch.tensor(box / 224.0).cuda().unsqueeze(0)]
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    labels=labels,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=added_bbox_list,
+                    add_box=added_bbox_list is not None,
+                    relations=None,
+                )
+            loss = outputs.loss
+            loss = (loss.sum() / (loss != 0).sum()).item()
+            losses.append(loss)
+        avg_ppl = np.array(losses).mean()
+        box_ppl.append(avg_ppl)
+    second_ppl = (np.array(box_ppl) * np.array(scores)).sum() / sum(scores)
+    return (first_ppl + second_ppl) / 2
+
+
+def evaluate_cola(
+    model,
+    tokenizer,
+    image_processor,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    debug=False,
+):
+    dataset_name = "cola"
+    dataset = json.load(open(DATASET))
+    model = model.cuda().eval()
+    correct = 0
+    total = 0
+    pbar = tqdm(dataset, disable=(rank != 0))
+    for ii, sample in enumerate(pbar):
+        if ii % world_size != rank:
+            continue
+        image1 = Image.open(os.path.join(VG_ROOT, os.path.basename(sample[0]))).convert("RGB").resize((224, 224))
+        text1 = sample[1]
+        image2 = Image.open(os.path.join(VG_ROOT, os.path.basename(sample[2]))).convert("RGB").resize((224, 224))
+        text2 = sample[3]
+        score11 = -get_score(image1, text1, model, tokenizer, image_processor, vis_embed_size)
+        score12 = -get_score(image1, text2, model, tokenizer, image_processor, vis_embed_size)
+        score21 = -get_score(image2, text1, model, tokenizer, image_processor, vis_embed_size)
+        score22 = -get_score(image2, text2, model, tokenizer, image_processor, vis_embed_size)
+        if rank == 0:
+            tqdm.write(f"{score11:.2f} {score12:.2f} {score21:.2f} {score22:.2f}")
+        if score11 > score21 and score22 > score12:
+            correct += 1
+        total += 1
+        pbar.set_description(f"{correct / total:.2f}")
+    print(rank, correct / total)
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}_{total}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+if __name__ == "__main__":
+    evaluate_cola(None, None, None)
diff --git a/multimodal/open_flamingo/eval/task/crepe.py b/multimodal/open_flamingo/eval/task/crepe.py
new file mode 100644
index 0000000000000000000000000000000000000000..604db5ce8eb303c259bc83a43eca2c54423658d3
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/crepe.py
@@ -0,0 +1,93 @@
+import json
+import webdataset as wds
+from tqdm import tqdm
+from PIL import Image
+import torch
+import numpy as np
+import os
+import time
+import cv2
+import random
+import pandas as pd
+from .vl_checklist import _eval_text_image
+DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/instruct_data/crepe/prod_hard_negatives"
+
+
+def evaluate_crepe(
+    model,
+    tokenizer,
+    image_processor,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    subset=True,
+    debug=False,
+    level=4,
+    type="swap",
+):
+    if rank == 0:
+        tqdm.write(f"level: {level}")
+        tqdm.write(f"type: {type}")
+    dataset_name = "crepe"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    model.eval().cuda()
+    total = 0
+    correct = 0
+    assert type in ["swap"]
+    assert 4 <= level <= 12
+    filename = os.path.join(DATASET_ROOT, type, f"prod_vg_hard_negs_{type}_complexity_{level}.csv")
+    df = pd.read_csv(filename)
+    pbar = tqdm(df.iterrows(), disable=(rank != 0))
+    for ii, sample in pbar:
+        if ii % world_size != rank:
+            continue
+        text = sample.caption
+        image_path = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/VG_100K/{}.jpg".format(sample.image_id)
+        x = sample.x
+        y = sample.y
+        width = sample.width
+        height = sample.height
+        image = Image.open(image_path).convert("RGB")
+        image = image.crop((x, y, x+width, y+height))
+        image = image.resize((224, 224))
+        final_rank, final_ranks = _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=debug)
+        if final_rank is None:
+            continue
+        correct += int((np.array(final_ranks) < 10).sum())
+        total += len(final_ranks)
+        if debug:
+            tqdm.write("="*80)
+        pbar.set_description(f"{text} | score: {correct / total:.4f} | {final_rank} | {final_ranks}")
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, correct]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        correct = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, correct_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            correct += correct_part
+        score = correct / total
+        print("score:", score, "total:", total)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{score}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
diff --git a/multimodal/open_flamingo/eval/task/gqa.py b/multimodal/open_flamingo/eval/task/gqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..70effb840fea1d5bb1e1dfbd8222951210bdadd7
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/gqa.py
@@ -0,0 +1,248 @@
+from torch.utils.data import Dataset
+import json
+from PIL import Image
+import os
+import torch
+import more_itertools
+from tqdm import tqdm
+import time
+from vqa_metric import compute_gqa_accuracy
+import string
+import uuid
+import numpy as np
+import cv2
+from open_flamingo.eval.task.utils import get_bbox
+
+class GQADataset(Dataset):
+    def __init__(
+        self,
+        image_dir_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/gqa/images",
+        annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/gqa/testdev_balanced_questions.json",
+    ):
+        annotations = json.load(open(annotations_path))
+        self.questions = []
+        self.answers = []
+        self.image_paths = []
+        self.question_ids = []
+        for anno_id in annotations:
+            question = annotations[anno_id]["question"]
+            imageId = annotations[anno_id]["imageId"]
+            answer = annotations[anno_id]["answer"]
+            self.questions.append(question)
+            self.answers.append(answer)
+            self.image_paths.append(os.path.join(image_dir_path, "{}.jpg".format(imageId)))
+            self.question_ids.append(anno_id)
+            # print(annotations[anno_id]["types"])
+        self.vqa_dataset = "gqa"
+
+    def __len__(self):
+        return len(self.questions)
+
+    def __getitem__(self, idx):
+        question = self.questions[idx]
+        question_id = self.question_ids[idx]
+        answer = self.answers[idx]
+        img_path = self.image_paths[idx]
+        image = Image.open(img_path)
+        return {
+            "image": image,
+            "question": question,
+            "answers": answer,
+            "question_id": question_id,
+        }
+
+
+def prepare_batch_images(batch, image_processor):
+    batch_images = None
+    for b in batch:
+        b_image = image_processor(b["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        if batch_images is None:
+            batch_images = b_image
+        else:
+            batch_images = torch.cat([batch_images, b_image], dim=0)
+    return batch_images
+
+
+
+def evaluate_gqa(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size=1,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    """
+    Evaluate a model on VQA datasets. Currently supports VQA v2.0.
+
+    Args:
+        model (nn.Module): model to evaluate
+        tokenizer (transformers.PreTrainedTokenizer): tokenizer for the model
+        image_processor : image processor for the model
+        batch_size (int): batch size
+        image_dir_path (str): path to image directory
+        questions_json_path (str): path to questions json file
+        annotations_json_path (str): path to annotations json file
+        seed (int, optional): random seed. Defaults to 42.
+        max_generation_length (int, optional): max generation length. Defaults to 5.
+        num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
+        length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
+        num_samples (int, optional): number of samples to evaluate on. Defaults to 5000 samples.
+        query_set_size (int, optional): size of the query set. Defaults to 2048.
+        num_shots (int, optional): number of shots to use. Defaults to 8.
+        device (int, optional): device to use. Defaults to -1 (cpu).
+        num_workers (int, optional): number of workers to use. Defaults to 4.
+        vqa_dataset (string): type of vqa dataset: currently supports vqa, ok_vqa. Defaults to vqa.
+    Returns:
+        float: accuracy score
+    """
+    assert batch_size == 1
+    vqa_dataset = "gqa"
+    eval_dataset = GQADataset()
+    object_token_id = tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+    bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    def get_prompt(sample):
+        return f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:"
+    model.eval().cuda()
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    predictions = []
+    if batch_size != 1:
+        tokenizer.padding_side = "left"
+    if world_size > 1:
+        torch.distributed.barrier()
+    this_tot = 0
+    for ii, batch in enumerate(more_itertools.chunked(
+        tqdm(eval_dataset, desc="Running inference", disable=(rank != 0)), batch_size,
+    )):
+        if ii % world_size != rank:
+            continue
+        batch[0]["image"] = batch[0]["image"].resize((224, 224))
+        batch_images = prepare_batch_images(
+            batch=batch,
+            image_processor=image_processor,
+        ).cuda()
+        batch_text = [get_prompt(s) for s in batch]
+        encodings = tokenizer(
+            batch_text,
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"].cuda()
+        attention_mask = encodings["attention_mask"].cuda()
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+            outputs = model.generate(
+                batch_images,
+                input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=10,
+                min_length=1,
+                num_beams=1,
+                # length_penalty=0,
+                image_start_index_list=image_start_index_list,
+                image_nums=image_nums,
+                added_bbox_list=None,
+                return_dict_in_generate=True,
+                output_scores=True,
+            )
+        scores = outputs.scores
+        outputs = outputs.sequences[:, len(input_ids[0]) :]
+        if object_token_id in scores[0][0].sort(descending=True).indices[:5]:
+            sample = batch[0]
+            # print("="*80)
+            # print("sample:", batch, scores[0][0].sort(descending=True).indices[:10].tolist().index(object_token_id))
+            prompt1 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer:<|#object#|><|#previsual#|>"]
+            boxes, scores = get_bbox(None, batch_images, prompt1, model, tokenizer, media_token_id, prebox_token_id, return_all=True)
+            # open_cv_image = np.array(sample["image"])
+            # open_cv_image = open_cv_image[:, :, ::-1].copy()
+            # cv2.imwrite(f"Atest_ori.png", open_cv_image)
+            # open_cv_image = cv2.rectangle(open_cv_image, boxes[0][:2].astype(int), boxes[0][2:].astype(int), (0, 255, 0), 2)
+            # print(scores)
+            # cv2.imwrite(f"Atest.png", open_cv_image)
+            if boxes is not None and len(boxes) > 0:
+                prompt2 = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>Question: {sample['question'].strip()} Short answer: it is<|#object#|><|#previsual#|><|#prebox#|><|#object#|> a"]
+                encodings = tokenizer(
+                    prompt2,
+                    return_tensors="pt",
+                    padding="longest",
+                    truncation=True,
+                    max_length=2000,
+                )
+                input_ids = encodings["input_ids"].cuda()
+                attention_mask = encodings["attention_mask"].cuda()
+                image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+                image_start_index_list = [[x] for x in image_start_index_list]
+                image_nums = [1] * len(input_ids)
+                added_bbox_list = [torch.tensor(boxes[0]/224.0).cuda().unsqueeze(0).clamp(0, 0.99)]
+                with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+                    outputs = model.generate(
+                        batch_images,
+                        input_ids,
+                        attention_mask=attention_mask,
+                        max_new_tokens=10,
+                        min_length=1,
+                        num_beams=1,
+                        image_start_index_list=image_start_index_list,
+                        image_nums=image_nums,
+                        added_bbox_list=added_bbox_list,
+                        eos_token_id=(endofobject_token_id),
+                    )
+                outputs = outputs[:, len(input_ids[0]) :]
+                # print("previsual===>{}".format(tokenizer.decode(outputs[0], skip_special_tokens=True).strip().lower().strip(string.punctuation+" ")))
+
+        # postprocess begin
+        new_predictions = [
+            out.strip().lower().strip(string.punctuation+" ") for out in tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ]
+        this_tot += 1
+        predictions.extend(
+            [
+                {"answer": p, "question_id": sample["question_id"], "_question": sample["question"], "answers": sample["answers"]}
+                for p, sample in zip(new_predictions, batch)
+            ]
+        )
+    with open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps(predictions))
+    print("save to", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank}_{id}.json")
+
+    time.sleep(10)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            print("load", f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+            predictions.extend(json.load(open(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")))
+            os.remove(f"{vqa_dataset}_{lang_encoder_name}_results_part{rank_i}_{id}.json")
+        print("num:", len(predictions))
+        # save the predictions to a temporary file
+        random_uuid = str(uuid.uuid4())
+        with open(f"{vqa_dataset}results_{random_uuid}.json", "w") as f:
+            f.write(json.dumps(predictions, indent=4))
+
+        acc = compute_gqa_accuracy(predictions)
+        print(vqa_dataset, "score:", acc, "| save to", f"{vqa_dataset}results_{random_uuid}.json")
+        os.makedirs("eval_results", exist_ok=True)
+        with open(os.path.join("eval_results", f"{vqa_dataset}_{model.expr_name}_{model.step_num}_{int(time.time())}_{acc}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+
+        # delete the temporary file
+        os.remove(f"{vqa_dataset}results_{random_uuid}.json")
+    else:
+        time.sleep(5)
+        acc = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return acc
diff --git a/multimodal/open_flamingo/eval/task/mmbench.py b/multimodal/open_flamingo/eval/task/mmbench.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6cdba9ce2b79d20ab22d00034ecd3b03ac78f5
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/mmbench.py
@@ -0,0 +1,84 @@
+import base64
+import io
+import random
+
+import pandas as pd
+from PIL import Image
+from torch.utils.data import Dataset
+from open_flamingo.eval.task.utils import get_object_from_text
+
+def decode_base64_to_image(base64_string):
+    image_data = base64.b64decode(base64_string)
+    image = Image.open(io.BytesIO(image_data))
+    return image
+
+class MMBenchDataset(Dataset):
+    def __init__(self,
+                 data_file,
+                 sys_prompt='There are several options:'):
+        self.df = pd.read_csv(data_file, sep='\t')
+        self.sys_prompt = sys_prompt
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        index = self.df.iloc[idx]['index']
+        image = self.df.iloc[idx]['image']
+        image = decode_base64_to_image(image)
+        question = self.df.iloc[idx]['question']
+        answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[0].keys() else None
+        catetory = self.df.iloc[idx]['category']
+        l2_catetory = self.df.iloc[idx]['l2-category']
+
+        option_candidate = ['A', 'B', 'C', 'D', 'E']
+        options = {
+            cand: self.load_from_df(idx, cand)
+            for cand in option_candidate
+            if self.load_from_df(idx, cand) is not None
+        }
+        options_prompt = f'{self.sys_prompt}\n'
+        for key, item in options.items():
+            options_prompt += f'{key}. {item}\n'
+
+        hint = self.load_from_df(idx, 'hint')
+        data = {
+            'img': image,
+            'question': question,
+            'answer': answer,
+            'options': options_prompt,
+            'category': catetory,
+            'l2-category': l2_catetory,
+            'options_dict': options,
+            'index': index,
+            'context': hint,
+        }
+        return data
+    def load_from_df(self, idx, key):
+        if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
+            return self.df.iloc[idx][key]
+        else:
+            return None
+
+
+def evaluate_mmbench(
+    model,
+    tokenizer,
+    image_processor,
+    batch_size=1,
+    image_dir_path=None,
+    questions_json_path=None,
+    annotations_json_path=None,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    dataset_name = "mmbench"
+    dataset = MMBenchDataset("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/mmbench/mmbench_dev_20230712.tsv")
+    for sample in dataset:
+        print(sample)
+
+
+if __name__ == '__main__':
+    evaluate_mmbench(None, None, None)
diff --git a/multimodal/open_flamingo/eval/task/others/generate_refcocog_reg_label.py b/multimodal/open_flamingo/eval/task/others/generate_refcocog_reg_label.py
new file mode 100644
index 0000000000000000000000000000000000000000..db4813da9321309da122b572267120a2cff4d278
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/others/generate_refcocog_reg_label.py
@@ -0,0 +1,71 @@
+import os
+import json
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+from io import BytesIO
+import base64
+import pickle
+
+REFCOCO_TSVFILE = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/refcocog/refcocog_val.tsv"
+
+def get_iou(box1, box2):
+    # box1 and box2 should be in the format [x1, y1, x2, y2]
+    intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
+                   max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
+    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    union = area_box1 + area_box2 - intersection
+    iou = intersection / union if union > 0 else 0
+    return iou
+
+
+if __name__ == "__main__":
+    data = {}
+    uniq_id_to_text = {}
+    uniq_id_to_image = {}
+    uniq_id_to_image_id = {}
+    annotations = dict(
+        annotations=[],
+        images=[],
+    )
+    with open(REFCOCO_TSVFILE, "r") as f:
+        lines = f.readlines()
+        for ii, line in enumerate(tqdm(lines)):
+            uniq_id, image_id, text, region_coord, image = line.split("\t")
+            image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
+            uniq_id_to_text[uniq_id] = text
+            uniq_id_to_image[uniq_id] = image
+            gt_box = np.array(list(map(float, region_coord.split(","))))
+            if image_id not in data:
+                data[image_id] = {}
+            duplicate = False
+            for box in data[image_id]:
+                if get_iou(gt_box, box) > 0.999:
+                    duplicate = True
+                    data[image_id][box].append(uniq_id)
+                    break
+            if not duplicate:
+                data[image_id][tuple(gt_box.tolist())] = [uniq_id]
+
+    region_id = -1
+    for image_id in data:
+        for region in data[image_id]:
+            region_id += 1
+            annotations["images"].append({"id": region_id})
+            for uniq_id in data[image_id][region]:
+                annotations["annotations"].append(
+                    {
+                        "image_id": region_id,
+                        "caption": uniq_id_to_text[uniq_id],
+                        "id": uniq_id,
+                    }
+                )
+                uniq_id_to_image_id[uniq_id] = region_id
+    pickle.dump({
+        "data": data,
+        "uniq_id_to_text": uniq_id_to_text,
+        "uniq_id_to_image": uniq_id_to_image,
+        "uniq_id_to_image_id": uniq_id_to_image_id,
+    }, open("refcocog_reg_val_data.pkl", "wb"))
+    json.dump(annotations, open("refcocog_reg_val_label.json", "w"), indent=2)
diff --git a/multimodal/open_flamingo/eval/task/others/generate_vlc_subset.py b/multimodal/open_flamingo/eval/task/others/generate_vlc_subset.py
new file mode 100644
index 0000000000000000000000000000000000000000..59cf20da9e672cab6ab5f0e12e33056befe97cf7
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/others/generate_vlc_subset.py
@@ -0,0 +1,7 @@
+import json
+import sys
+import random
+size = int(sys.argv[-1])
+dataset = json.load(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/vlc_data.json"))
+subset = random.choices(dataset, k=size)
+json.dump(subset, open(f"vlc_data_subset_{size//1000}k.json", "w"), indent=1)
diff --git a/multimodal/open_flamingo/eval/task/reg.py b/multimodal/open_flamingo/eval/task/reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..61758cf8df154f69f830693c4c1b27513b744a82
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/reg.py
@@ -0,0 +1,141 @@
+import torch
+from tqdm import tqdm
+from PIL import Image
+from io import BytesIO
+import base64
+import numpy as np
+import time
+import json
+import os
+import cv2
+from coco_metric import compute_cider
+import random
+import pickle
+
+def evaluate_reg(
+    model,
+    tokenizer,
+    image_processor,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+):
+    lang_encoder_name = model.lang_encoder.__class__.__name__.lower()
+    dataset_name = "refcocog"
+    pkl_file = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_data.pkl"
+    try:
+        media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+        endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+        pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1]
+        bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1]
+    except:
+        pass
+
+    model.eval().cuda()
+    if world_size > 1:
+        torch.distributed.barrier()
+    this_tot = 0
+    predictions = []
+    D = pickle.load(open(pkl_file, "rb"))
+    lines = []
+    data = D["data"]
+    uniq_id_to_text = D["uniq_id_to_text"]
+    uniq_id_to_image = D["uniq_id_to_image"]
+    uniq_id_to_image_id = D["uniq_id_to_image_id"]
+    for image_id in data:
+        for region in data[image_id]:
+            uniq_id = data[image_id][region][0]
+            lines.append([uniq_id, uniq_id_to_image_id[uniq_id], [uniq_id_to_text[r] for r in data[image_id][region]], region, uniq_id_to_image[uniq_id]])
+    print("total data:", len(lines))
+    # lines = lines[:20]
+    pbar = tqdm(lines, disable=(rank != 0))
+    for ii, line in enumerate(pbar):
+        if ii % world_size != rank:
+            continue
+        uniq_id, image_id, text, region_coord, image = line
+        gt_box = np.array(region_coord)
+        width = image.width
+        height = image.height
+        image = image.resize((224, 224))
+        gt_box = gt_box / np.array([width, height, width, height]) * 224
+        batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#object#|><|#previsual#|><|#prebox#|><|#object#|>"]
+
+        encodings = tokenizer(
+            prompt,
+            padding="longest",
+            truncation=True,
+            return_tensors="pt",
+            max_length=2000,
+        )
+        input_ids = encodings["input_ids"]
+        attention_mask = encodings["attention_mask"]
+        image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+        image_start_index_list = [[x] for x in image_start_index_list]
+        image_nums = [1] * len(input_ids)
+        batch_images = batch_images.cuda()
+        input_ids = input_ids.cuda()
+        attention_mask = attention_mask.cuda()
+        added_bbox_list = [(torch.tensor(gt_box).cuda() / 224).clamp(0, 0.99).unsqueeze(0)]
+
+        with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+            outputs = model.generate(
+                batch_images,
+                input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=25,
+                min_length=5,
+                num_beams=8,
+                length_penalty=0,
+                image_start_index_list=image_start_index_list,
+                image_nums=image_nums,
+                added_bbox_list=added_bbox_list,
+            )
+        outputs = outputs[:, len(input_ids[0]) :]
+        new_prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip().lower()
+        this_tot += 1
+        if rank == 0 and this_tot % 10 == 0:
+            for i in range(1):
+                tqdm.write(f"answer: {text}\nmodel output: {new_prediction}")
+        predictions.append(
+            {"image_id": image_id, "caption": new_prediction}
+        )
+    results_path = f"reg_{lang_encoder_name}_{rank}_{id}.json"
+    json.dump(predictions, open(results_path, "w"))
+    print("save to", results_path)
+    del predictions
+    time.sleep(5)
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        predictions = []
+        for rank_i in range(world_size):
+            part_results_path = f"reg_{lang_encoder_name}_{rank_i}_{id}.json"
+            print("load", part_results_path)
+            part_data = json.load(open(part_results_path))
+            predictions.extend(part_data)
+            os.remove(part_results_path)
+        print("num:", len(predictions))
+        results_path = f"reg_{lang_encoder_name}_{id}_result.json"
+        json.dump(predictions, open(results_path, "w"), indent=2)
+
+        metrics = compute_cider(
+            result_path=results_path,
+            annotations_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/others/refcocog_reg_val_label.json",
+        )
+        os.makedirs("eval_results", exist_ok=True)
+        cider = metrics["CIDEr"]
+        print("cider", cider)
+        with open(os.path.join("eval_results", f"reg_{model.expr_name}_{model.step_num}_{int(time.time())}_{cider}"), "w") as f:
+            f.write(json.dumps(predictions, indent=2))
+        # delete the temporary file
+        os.remove(results_path)
+        return cider
+
+
+if __name__ == "__main__":
+    anno = json.load(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco_gt/coco_karpathy_test_gt.json"))
+    import pdb; pdb.set_trace()
+    print(anno.keys())
diff --git a/multimodal/open_flamingo/eval/task/utils.py b/multimodal/open_flamingo/eval/task/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8f1d90ba8ef797be43ef2b6e63b6e8746124504
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/utils.py
@@ -0,0 +1,287 @@
+import spacy
+import torch
+from tqdm import tqdm
+import numpy as np
+import itertools
+nlp = spacy.load('en_core_web_md')
+
+
+def get_iou(box1, box2):
+    # box1 and box2 should be in the format [x1, y1, x2, y2]
+    intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
+                   max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
+    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    union = area_box1 + area_box2 - intersection
+    iou = intersection / union if union > 0 else 0
+    return iou
+
+
+# def find_root(token):
+#     if token.pos_ == "VERB":
+#         return token
+#     while token.dep_ not in ["pobj", "nsubj", "ROOT", "npadvmod", "dobj", "det", "prep", "punct", "cc", "conj", "acl", "dep", "appos", "relcl", "advmod", "nmod", "attr"]:
+#         token = token.head
+#     return token
+
+
+def find_root(token):
+    if token.pos_ == "VERB":
+        return token
+    while token.dep_ in ["compound", "amod"]:
+        token = token.head
+    return token
+
+def get_object_from_text(text, verbose=False):
+    if len(text.split(" ")) == 3:
+        text = text.split(" ")
+        return [text[0], text[-1]]
+    doc = nlp(text)
+    if verbose:
+        for TT in doc:
+            print(TT.text, TT.pos_, TT.dep_, TT.head)
+    roots = set()
+    for i, token in enumerate(doc):
+        roots.add(find_root(token))
+    exprs = []
+    roots = sorted(list(roots), key=lambda token: token.idx)
+    first_nsubj = True
+    if verbose:
+        print(roots)
+    for root in roots:
+        if root.pos_ not in ["NOUN", "PROPN"]:
+            continue
+        if root.dep_ not in ["pobj", "nsubj"]:
+            continue
+        if not first_nsubj and root.dep_ in ["nsubj"]:
+            continue
+        exprs.append([])
+        for token in doc:
+            if find_root(token) == root:
+                exprs[-1].append(token.text)
+        exprs[-1] = " ".join(exprs[-1]).replace(" '", "'")
+        if exprs[-1] not in text:
+            if verbose:
+                print("not in text error:", exprs[-1], "#",text)
+            # for TT in doc:
+            #     print(TT.text, TT.pos_, TT.dep_, TT.head)
+            # import pdb; pdb.set_trace()
+            exprs.pop()
+        if first_nsubj and root.dep_ in ["nsubj"]:
+            first_nsubj = False
+    if len(exprs) <= 1:
+        if verbose:
+            print("not enough exprs error:", exprs, "#",text)
+        return []
+    return exprs
+
+def is_correct(input_ids, logits, tokenizer, object: str, topk=5, N=10):
+    answer_id = torch.tensor(tokenizer(f" {object}", add_special_tokens=False)["input_ids"]).to(input_ids.device)
+    answer_begin_idx = (input_ids == answer_id[0]).nonzero()
+    answer_idx = None
+    for (batch_idx, IDX) in answer_begin_idx:
+        try:
+            if (input_ids[batch_idx, IDX:IDX+len(answer_id)] == answer_id).all():
+                answer_idx = list(range(IDX-1, IDX+len(answer_id)-1))
+        except:
+            pass
+    if answer_idx is None:
+        return np.inf, False, False
+    res = logits[0, answer_idx].softmax(-1).sort(descending=True)
+    values = res.values
+    indices = res.indices
+    chosen_ids = list(itertools.product(*([list(range(N))]*len(answer_idx))))
+    probs = []
+    for ids in chosen_ids:
+        prob = 1.0
+        for i, id in enumerate(ids):
+            prob *= values[i, id]
+        probs.append((prob.item(), ids))
+    probs.sort(reverse=True)
+    answer_pos = tuple([id_array.tolist().index(idx) for id_array, idx in zip(indices, answer_id)])
+    ranking = [p[1] for p in probs]
+    # if len(answer_idx) > 1:
+    #     import pdb; pdb.set_trace()
+    try:
+        r = ranking.index(answer_pos)
+        return r, r < 1, r < 5
+    except:
+        return np.inf, False, False
+
+def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
+    assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
+    encodings = tokenizer(
+        prompt,
+        padding="longest",
+        truncation=True,
+        return_tensors="pt",
+        max_length=2000,
+    )
+    input_ids = encodings["input_ids"]
+    attention_mask = encodings["attention_mask"]
+    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+    image_start_index_list = [[x] for x in image_start_index_list]
+    image_nums = [1] * len(input_ids)
+    vision_x = batch_images.cuda()
+    lang_x = input_ids.cuda()
+    attention_mask = attention_mask.cuda()
+
+    model.debug_id = 0
+    with torch.inference_mode() and torch.cuda.amp.autocast(dtype=torch.float16):
+        outputs = model(
+            vision_x=vision_x,
+            lang_x=lang_x,
+            attention_mask=attention_mask,
+            labels=None,
+            image_nums=image_nums,
+            image_start_index_list=image_start_index_list,
+            added_bbox_list=visual_box_list,
+            add_box=visual_box_list is not None,
+            relations=None,
+            debug_mode=False,
+        )
+    boxes = outputs["boxes"]
+    scores = outputs["scores"]
+    if debug:
+        import pdb; pdb.set_trace()
+    if return_all:
+        return boxes, scores
+    if len(scores) == 0:
+        return None, None
+    else:
+        return boxes[scores.argmax()], scores.max()
+
+
+def _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=False, objects=None):
+    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
+    if objects is None:
+        objects = get_object_from_text(text)
+    if len(objects) == 0:
+        return None, None, None
+    if debug:
+        tqdm.write(text)
+        tqdm.write(f"{objects}")
+    first_idx = text.find(objects[0])
+    if first_idx == 0:
+        first_text = f"<|#object#|>{objects[0]}<|#endofobject#|><|#visual#|>"
+    else:
+        first_text = text[:first_idx-1] + f"<|#object#|> {objects[0]}<|#endofobject#|><|#visual#|>"
+    
+    if debug:
+        tqdm.write(first_text)
+    prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{first_text}"]
+    # import pdb; pdb.set_trace()
+    # print("do first get_bbox |", first_text)
+    first_box, first_score = get_bbox(None, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
+    if not model.valid and debug:
+        import pdb; pdb.set_trace()
+    if first_box is not None:
+        added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
+        text = first_text + "<|#box#|><|#endofobject#|>" + text[first_idx+len(objects[0]):]
+    else:
+        added_bbox_list = []
+
+    final_ranks = []
+    is_top1_list = []
+    is_top5_list = []
+    for kk, object in enumerate(objects):
+        if kk == 0:
+            continue
+        idx = text.find(objects[0])
+        for t_i, temp in enumerate(objects[1:kk+1]):
+            # t_i is actually the previous one. This is not a bug
+            idx = text.find(temp, idx + len(objects[t_i]))
+            while idx+len(temp) != len(text) and (text[idx-1] == "#" or text[idx+len(temp)] == "#"):
+                # in case temp is box or object or visual or something like that
+                idx = text.find(temp, idx + len(temp))
+        this_text = text[:idx-1] + "<|#object#|><|#previsual#|>"
+        # if this_text == "<|#object#|><|#previsual#|>":
+        #     import pdb; pdb.set_trace()
+        if debug:
+            tqdm.write(this_text)
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
+        # import pdb; pdb.set_trace()
+        # print("do pre get_bbox |", this_text)
+        pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, 
+        prebox_token_id, return_all=True)
+        if not model.valid and debug:
+            import pdb; pdb.set_trace()
+        logits_list = []
+        # pre_boxes = [pre_boxes[0]]
+        # pre_scores = [pre_scores[0]]
+        this_text = this_text + f"<|#prebox#|><|#object#|> {object}<|#endofobject#|>"
+        for pre_box, pre_score in zip(pre_boxes, pre_scores):
+            prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
+            encodings = tokenizer(
+                prompt,
+                padding="longest",
+                truncation=True,
+                return_tensors="pt",
+                max_length=512,
+            )
+            input_ids = encodings["input_ids"]
+            attention_mask = encodings["attention_mask"]
+            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
+            image_start_index_list = [[x] for x in image_start_index_list]
+            image_nums = [1] * len(input_ids)
+            vision_x = batch_images.cuda()
+            lang_x = input_ids.cuda()
+            attention_mask = attention_mask.cuda()
+            this_added_bbox_list = added_bbox_list + [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
+
+            with torch.cuda.amp.autocast(dtype=torch.float16) and torch.no_grad():
+                outputs = model(
+                    vision_x=vision_x,
+                    lang_x=lang_x,
+                    attention_mask=attention_mask,
+                    image_nums=image_nums,
+                    image_start_index_list=image_start_index_list,
+                    added_bbox_list=this_added_bbox_list,
+                    add_box=this_added_bbox_list is not None and len(this_added_bbox_list) != 0,
+                    relations=None,
+                )
+            if not model.valid and debug:
+                import pdb; pdb.set_trace()
+            logits_list.append([pre_score, outputs.logits])
+            if debug:
+                answer_start_idx = (lang_x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1]
+                logits = outputs["logits"][0, answer_start_idx:]
+                tqdm.write(tokenizer.decode(logits[0].sort(descending=True).indices.tolist()[:10]))
+            # if debug:
+            #     image.save("Atest.png")
+            #     open_cv_image = np.array(image)
+            #     open_cv_image = open_cv_image[:, :, ::-1].copy()
+            #     if first_box is not None:
+            #         open_cv_image = cv2.rectangle(open_cv_image, first_box[:2].astype(int), first_box[2:].astype(int), (255, 0, 0), 2)
+            #     if pre_box is not None:
+            #         open_cv_image = cv2.rectangle(open_cv_image, pre_box[:2].astype(int), pre_box[2:].astype(int), (0, 255, 0), 2)
+            #     cv2.imwrite(f"Atest.png", open_cv_image)
+            #     import pdb; pdb.set_trace()
+        pre_scores = np.array([x[0] for x in logits_list])
+        final_probs = 0.0
+        for score, (_, logits) in zip(pre_scores, logits_list):
+            final_probs += score * logits.softmax(-1)
+        assert input_ids.shape[:2] == final_probs.shape[:2]
+        _rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, object, topk=5)
+        final_ranks.append(_rank)
+        is_top1_list.append(is_top1)
+        is_top5_list.append(is_top5)
+        this_text = text[:idx-1] + f"<|#object#|> {object}<|#endofobject#|><|#visual#|>"
+        if debug:
+            tqdm.write(this_text)
+        prompt = [f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>{this_text}"]
+        # print("do this get_bbox |", this_text)
+        this_box, this_score = get_bbox(added_bbox_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, return_all=False)
+        if not model.valid and debug:
+            import pdb; pdb.set_trace()
+        if this_box is not None:
+            added_bbox_list += [torch.tensor(this_box).unsqueeze(0).cuda() / 224]
+            text = this_text + "<|#box#|><|#endofobject#|>" + text[idx+len(object):]
+    return final_ranks, is_top1_list, is_top5_list
+
+
+
+
+if __name__ == "__main__":
+    # print(get_object_from_text("there is a cookie. there is a bear. white orio cookie is next to the teddy bear. car runs on the traffic road. there is a tree.", verbose=False))
+    print(get_object_from_text("President speaks to an American at a business office",verbose=True))
diff --git a/multimodal/open_flamingo/eval/task/vl_checklist.py b/multimodal/open_flamingo/eval/task/vl_checklist.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f92ec252728cffbda227cb841de824fa60a3882
--- /dev/null
+++ b/multimodal/open_flamingo/eval/task/vl_checklist.py
@@ -0,0 +1,113 @@
+import json
+import webdataset as wds
+from tqdm import tqdm
+from PIL import Image
+import torch
+import numpy as np
+import os
+import time
+import cv2
+import random
+from open_flamingo.eval.task.utils import (
+    get_object_from_text,
+    is_correct,
+    _eval_text_image,
+)
+DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/cdl/instruct_data/instruct/vl_checklist/Relation/000000.tar"
+
+def evaluate_vlc(
+    model,
+    tokenizer,
+    image_processor,
+    vis_embed_size=None,
+    rank=0,
+    world_size=1,
+    id=0,
+    subset=True,
+    subset_size="5k",
+    debug=False,
+):
+    dataset_name = "vlc"
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+    endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+    model.eval().cuda()
+    total = 0
+    n_top1 = 0
+    n_top5 = 0
+    n_top10 = 0
+    filename = "/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/vlc_data.json" if not subset else f"/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/open_flamingo/eval/task/vlc_data_subset_{subset_size}.json"
+    dataset = json.load(open(filename))
+
+    pbar = tqdm(dataset, disable=(rank != 0))
+    for ii, sample in enumerate(pbar):
+        if ii % world_size != rank:
+            continue
+        text, image_path = sample
+        image = Image.open(image_path).convert("RGB")
+        image = image.resize((224, 224))
+        final_ranks, is_top1_list, is_top5_list = _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=debug)
+        if final_ranks is None:
+            continue
+        n_top1 += int(sum(is_top1_list))
+        n_top5 += int(sum(is_top5_list))
+        n_top10 += int((np.array(final_ranks) < 10).sum())
+        total += len(final_ranks)
+        if debug:
+            tqdm.write("="*80)
+        pbar.set_description(f"acc@top1: {n_top1 / total:.4f} | acc@top5: {n_top5 / total:.4f} | acc@top10: {n_top10 / total:.4f} | {final_ranks} |{text}")
+
+
+    with open(f"{dataset_name}_results_part{rank}_{id}.json", "w") as f:
+        f.write(json.dumps([total, n_top1, n_top5, n_top10]))
+    if world_size > 1:
+        torch.distributed.barrier()
+    if rank == 0:
+        total = 0
+        n_top1 = 0
+        n_top5 = 0
+        n_top10 = 0
+        print(f"evaluate on rank {rank}. world size is {world_size}")
+        for rank_i in range(world_size):
+            [total_part, n_top1_part, n_top5_part, n_top10_part] = json.load(open(f"{dataset_name}_results_part{rank_i}_{id}.json"))
+            os.remove(f"{dataset_name}_results_part{rank_i}_{id}.json")
+            total += total_part
+            n_top1 += n_top1_part
+            n_top5 += n_top5_part
+            n_top10 += n_top10_part
+        print("acc@top1:", n_top1 / total, "acc@top5:", n_top5 / total, "acc@top10:", n_top10 / total, "total:", total)
+        with open(os.path.join("eval_results", f"{dataset_name}_{model.expr_name}_{model.step_num}_{int(time.time())}_{n_top1 / total}_{n_top5 / total}_{n_top10 / total}_{total}"), "w") as f:
+            pass
+    else:
+        score = 0.0
+    if world_size > 1:
+        torch.distributed.barrier()
+    return score
+
+
+if __name__ == "__main__":
+    dataset = wds.WebDataset(DATASET_ROOT).decode().shuffle(100000).to_tuple("data.pyd", "dataset.txt", "image_path.txt")
+    labels = set()
+    texts = []
+    data_pair = []
+    if not os.path.exists("vlc_data.json"):
+        for sample in tqdm(dataset):
+            data, dataset_name, image_path = sample
+            text = data[-1]["POS"][0]
+            texts.append(text)
+            data_pair.append([text, image_path])
+        json.dump(data_pair, open("vlc_data.json", "w"), indent=1)
+    else:
+        print("data exists")
+        data_pair = json.load(open("vlc_data.json"))
+        for text, image_path in data_pair:
+            texts.append(text)
+
+
+    
+    print(get_object_from_text("crow attacks the dove"))
diff --git a/multimodal/open_flamingo/eval/vqa_metric.py b/multimodal/open_flamingo/eval/vqa_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..e887b5c3839c5b7cfb1857356bb48c344014e28b
--- /dev/null
+++ b/multimodal/open_flamingo/eval/vqa_metric.py
@@ -0,0 +1,594 @@
+import copy
+import datetime
+import json
+import os
+import random
+import re
+import sys
+
+# Interface for accessing the VQA dataset.
+
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
+
+# The following functions are defined:
+#  VQA        - VQA class that loads VQA annotation file and prepares data structures.
+#  getQuesIds - Get question ids that satisfy given filter conditions.
+#  getImgIds  - Get image ids that satisfy given filter conditions.
+#  loadQA     - Load questions and answers with the specified question ids.
+#  showQA     - Display the specified questions and answers.
+#  loadRes    - Load result file and create result object.
+
+# Help on each function can be accessed by: "help(COCO.function)"
+
+
+class VQA:
+    def __init__(self, annotation_file=None, question_file=None):
+        """
+        Constructor of VQA helper class for reading and visualizing questions and answers.
+        :param annotation_file (str): location of VQA annotation file
+        :return:
+        """
+        # load dataset
+        self.dataset = {}
+        self.questions = {}
+        self.qa = {}
+        self.qqa = {}
+        self.imgToQA = {}
+        if not annotation_file == None and not question_file == None:
+            print("loading VQA annotations and questions into memory...")
+            time_t = datetime.datetime.utcnow()
+            dataset = json.load(open(annotation_file, "r"))
+            questions = json.load(open(question_file, "r"))
+            print(datetime.datetime.utcnow() - time_t)
+            self.dataset = dataset
+            self.questions = questions
+            self.createIndex()
+
+    def createIndex(self):
+        # create index
+        print("creating index...")
+        imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
+        qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
+        qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
+        for ann in self.dataset["annotations"]:
+            imgToQA[ann["image_id"]] += [ann]
+            qa[ann["question_id"]] = ann
+        for ques in self.questions["questions"]:
+            qqa[ques["question_id"]] = ques
+        print("index created!")
+
+        # create class members
+        self.qa = qa
+        self.qqa = qqa
+        self.imgToQA = imgToQA
+
+    def info(self):
+        """
+        Print information about the VQA annotation file.
+        :return:
+        """
+        for key, value in self.dataset["info"].items():
+            print("%s: %s" % (key, value))
+
+    def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
+        """
+        Get question ids that satisfy given filter conditions. default skips that filter
+        :param 	imgIds    (int array)   : get question ids for given imgs
+                        quesTypes (str array)   : get question ids for given question types
+                        ansTypes  (str array)   : get question ids for given answer types
+        :return:    ids   (int array)   : integer array of question ids
+        """
+        imgIds = imgIds if type(imgIds) == list else [imgIds]
+        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+        if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
+            anns = self.dataset["annotations"]
+        else:
+            if not len(imgIds) == 0:
+                anns = sum(
+                    [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
+                    [],
+                )
+            else:
+                anns = self.dataset["annotations"]
+            anns = (
+                anns
+                if len(quesTypes) == 0
+                else [ann for ann in anns if ann["question_type"] in quesTypes]
+            )
+            anns = (
+                anns
+                if len(ansTypes) == 0
+                else [ann for ann in anns if ann["answer_type"] in ansTypes]
+            )
+        ids = [ann["question_id"] for ann in anns]
+        return ids
+
+    def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
+        """
+         Get image ids that satisfy given filter conditions. default skips that filter
+         :param quesIds   (int array)   : get image ids for given question ids
+        quesTypes (str array)   : get image ids for given question types
+        ansTypes  (str array)   : get image ids for given answer types
+         :return: ids     (int array)   : integer array of image ids
+        """
+        quesIds = quesIds if type(quesIds) == list else [quesIds]
+        quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+        ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+        if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
+            anns = self.dataset["annotations"]
+        else:
+            if not len(quesIds) == 0:
+                anns = sum(
+                    [self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
+                )
+            else:
+                anns = self.dataset["annotations"]
+            anns = (
+                anns
+                if len(quesTypes) == 0
+                else [ann for ann in anns if ann["question_type"] in quesTypes]
+            )
+            anns = (
+                anns
+                if len(ansTypes) == 0
+                else [ann for ann in anns if ann["answer_type"] in ansTypes]
+            )
+        ids = [ann["image_id"] for ann in anns]
+        return ids
+
+    def loadQA(self, ids=[]):
+        """
+        Load questions and answers with the specified question ids.
+        :param ids (int array)       : integer ids specifying question ids
+        :return: qa (object array)   : loaded qa objects
+        """
+        if type(ids) == list:
+            return [self.qa[id] for id in ids]
+        elif type(ids) == int:
+            return [self.qa[ids]]
+
+    def showQA(self, anns):
+        """
+        Display the specified annotations.
+        :param anns (array of object): annotations to display
+        :return: None
+        """
+        if len(anns) == 0:
+            return 0
+        for ann in anns:
+            quesId = ann["question_id"]
+            print("Question: %s" % (self.qqa[quesId]["question"]))
+            for ans in ann["answers"]:
+                print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
+
+    def loadRes(self, resFile, quesFile):
+        """
+        Load result file and return a result object.
+        :param   resFile (str)     : file name of result file
+        :return: res (obj)         : result api object
+        """
+        res = VQA()
+        res.questions = json.load(open(quesFile))
+        res.dataset["info"] = copy.deepcopy(self.questions["info"])
+        res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
+        res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
+        res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
+        res.dataset["license"] = copy.deepcopy(self.questions["license"])
+
+        print("Loading and preparing results...     ")
+        time_t = datetime.datetime.utcnow()
+        anns = json.load(open(resFile))
+        assert type(anns) == list, "results is not an array of objects"
+        annsQuesIds = [ann["question_id"] for ann in anns]
+        # print set of question ids that do not have corresponding annotations
+
+        # assert set(annsQuesIds) == set(self.getQuesIds()), \
+        # 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
+        for ann in anns:
+            quesId = ann["question_id"]
+            if res.dataset["task_type"] == "Multiple Choice":
+                assert (
+                    ann["answer"] in self.qqa[quesId]["multiple_choices"]
+                ), "predicted answer is not one of the multiple choices"
+            qaAnn = self.qa[quesId]
+            ann["image_id"] = qaAnn["image_id"]
+            ann["question_type"] = qaAnn["question_type"]
+            ann["answer_type"] = qaAnn["answer_type"]
+        print(
+            "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
+        )
+
+        res.dataset["annotations"] = anns
+        res.createIndex()
+        return res
+
+
+class VQAEval:
+    def __init__(self, vqa=None, vqaRes=None, n=2):
+        self.n = n
+        self.accuracy = {}
+        self.evalQA = {}
+        self.evalQuesType = {}
+        self.evalAnsType = {}
+        self.vqa = vqa
+        self.vqaRes = vqaRes
+        if vqaRes is not None:
+            self.params = {"question_id": vqaRes.getQuesIds()}
+        self.contractions = {
+            "aint": "ain't",
+            "arent": "aren't",
+            "cant": "can't",
+            "couldve": "could've",
+            "couldnt": "couldn't",
+            "couldn'tve": "couldn't've",
+            "couldnt've": "couldn't've",
+            "didnt": "didn't",
+            "doesnt": "doesn't",
+            "dont": "don't",
+            "hadnt": "hadn't",
+            "hadnt've": "hadn't've",
+            "hadn'tve": "hadn't've",
+            "hasnt": "hasn't",
+            "havent": "haven't",
+            "hed": "he'd",
+            "hed've": "he'd've",
+            "he'dve": "he'd've",
+            "hes": "he's",
+            "howd": "how'd",
+            "howll": "how'll",
+            "hows": "how's",
+            "Id've": "I'd've",
+            "I'dve": "I'd've",
+            "Im": "I'm",
+            "Ive": "I've",
+            "isnt": "isn't",
+            "itd": "it'd",
+            "itd've": "it'd've",
+            "it'dve": "it'd've",
+            "itll": "it'll",
+            "let's": "let's",
+            "maam": "ma'am",
+            "mightnt": "mightn't",
+            "mightnt've": "mightn't've",
+            "mightn'tve": "mightn't've",
+            "mightve": "might've",
+            "mustnt": "mustn't",
+            "mustve": "must've",
+            "neednt": "needn't",
+            "notve": "not've",
+            "oclock": "o'clock",
+            "oughtnt": "oughtn't",
+            "ow's'at": "'ow's'at",
+            "'ows'at": "'ow's'at",
+            "'ow'sat": "'ow's'at",
+            "shant": "shan't",
+            "shed've": "she'd've",
+            "she'dve": "she'd've",
+            "she's": "she's",
+            "shouldve": "should've",
+            "shouldnt": "shouldn't",
+            "shouldnt've": "shouldn't've",
+            "shouldn'tve": "shouldn't've",
+            "somebody'd": "somebodyd",
+            "somebodyd've": "somebody'd've",
+            "somebody'dve": "somebody'd've",
+            "somebodyll": "somebody'll",
+            "somebodys": "somebody's",
+            "someoned": "someone'd",
+            "someoned've": "someone'd've",
+            "someone'dve": "someone'd've",
+            "someonell": "someone'll",
+            "someones": "someone's",
+            "somethingd": "something'd",
+            "somethingd've": "something'd've",
+            "something'dve": "something'd've",
+            "somethingll": "something'll",
+            "thats": "that's",
+            "thered": "there'd",
+            "thered've": "there'd've",
+            "there'dve": "there'd've",
+            "therere": "there're",
+            "theres": "there's",
+            "theyd": "they'd",
+            "theyd've": "they'd've",
+            "they'dve": "they'd've",
+            "theyll": "they'll",
+            "theyre": "they're",
+            "theyve": "they've",
+            "twas": "'twas",
+            "wasnt": "wasn't",
+            "wed've": "we'd've",
+            "we'dve": "we'd've",
+            "weve": "we've",
+            "werent": "weren't",
+            "whatll": "what'll",
+            "whatre": "what're",
+            "whats": "what's",
+            "whatve": "what've",
+            "whens": "when's",
+            "whered": "where'd",
+            "wheres": "where's",
+            "whereve": "where've",
+            "whod": "who'd",
+            "whod've": "who'd've",
+            "who'dve": "who'd've",
+            "wholl": "who'll",
+            "whos": "who's",
+            "whove": "who've",
+            "whyll": "why'll",
+            "whyre": "why're",
+            "whys": "why's",
+            "wont": "won't",
+            "wouldve": "would've",
+            "wouldnt": "wouldn't",
+            "wouldnt've": "wouldn't've",
+            "wouldn'tve": "wouldn't've",
+            "yall": "y'all",
+            "yall'll": "y'all'll",
+            "y'allll": "y'all'll",
+            "yall'd've": "y'all'd've",
+            "y'alld've": "y'all'd've",
+            "y'all'dve": "y'all'd've",
+            "youd": "you'd",
+            "youd've": "you'd've",
+            "you'dve": "you'd've",
+            "youll": "you'll",
+            "youre": "you're",
+            "youve": "you've",
+        }
+        self.manualMap = {
+            "none": "0",
+            "zero": "0",
+            "one": "1",
+            "two": "2",
+            "three": "3",
+            "four": "4",
+            "five": "5",
+            "six": "6",
+            "seven": "7",
+            "eight": "8",
+            "nine": "9",
+            "ten": "10",
+        }
+        self.articles = ["a", "an", "the"]
+
+        self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
+        self.commaStrip = re.compile("(\d)(\,)(\d)")
+        self.punct = [
+            ";",
+            r"/",
+            "[",
+            "]",
+            '"',
+            "{",
+            "}",
+            "(",
+            ")",
+            "=",
+            "+",
+            "\\",
+            "_",
+            "-",
+            ">",
+            "<",
+            "@",
+            "`",
+            ",",
+            "?",
+            "!",
+        ]
+
+    def evaluate(self, quesIds=None):
+        if quesIds == None:
+            quesIds = [quesId for quesId in self.params["question_id"]]
+        gts = {}
+        res = {}
+        for quesId in quesIds:
+            gts[quesId] = self.vqa.qa[quesId]
+            res[quesId] = self.vqaRes.qa[quesId]
+
+        # =================================================
+        # Compute accuracy
+        # =================================================
+        accQA = []
+        accQuesType = {}
+        accAnsType = {}
+        print("computing accuracy")
+        step = 0
+        for quesId in quesIds:
+            for ansDic in gts[quesId]["answers"]:
+                ansDic["answer"] = ansDic["answer"].replace("\n", " ")
+                ansDic["answer"] = ansDic["answer"].replace("\t", " ")
+                ansDic["answer"] = ansDic["answer"].strip()
+            resAns = res[quesId]["answer"]
+            resAns = resAns.replace("\n", " ")
+            resAns = resAns.replace("\t", " ")
+            resAns = resAns.strip()
+            gtAcc = []
+            gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
+
+            if len(set(gtAnswers)) > 1:
+                for ansDic in gts[quesId]["answers"]:
+                    ansDic["answer"] = self.processPunctuation(ansDic["answer"])
+                    ansDic["answer"] = self.processDigitArticle(ansDic["answer"])
+                resAns = self.processPunctuation(resAns)
+                resAns = self.processDigitArticle(resAns)
+
+            for gtAnsDatum in gts[quesId]["answers"]:
+                otherGTAns = [
+                    item for item in gts[quesId]["answers"] if item != gtAnsDatum
+                ]
+                matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
+                acc = min(1, float(len(matchingAns)) / 3)
+                gtAcc.append(acc)
+            quesType = gts[quesId]["question_type"]
+            ansType = gts[quesId]["answer_type"]
+            avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
+            accQA.append(avgGTAcc)
+            if quesType not in accQuesType:
+                accQuesType[quesType] = []
+            accQuesType[quesType].append(avgGTAcc)
+            if ansType not in accAnsType:
+                accAnsType[ansType] = []
+            accAnsType[ansType].append(avgGTAcc)
+            self.setEvalQA(quesId, avgGTAcc)
+            self.setEvalQuesType(quesId, quesType, avgGTAcc)
+            self.setEvalAnsType(quesId, ansType, avgGTAcc)
+            if step % 100 == 0:
+                self.updateProgress(step / float(len(quesIds)))
+            step = step + 1
+
+        self.setAccuracy(accQA, accQuesType, accAnsType)
+        print("Done computing accuracy")
+
+    def processPunctuation(self, inText):
+        outText = inText
+        for p in self.punct:
+            if (p + " " in inText or " " + p in inText) or (
+                re.search(self.commaStrip, inText) != None
+            ):
+                outText = outText.replace(p, "")
+            else:
+                outText = outText.replace(p, " ")
+        outText = self.periodStrip.sub("", outText, re.UNICODE)
+        return outText
+
+    def processDigitArticle(self, inText):
+        outText = []
+        tempText = inText.lower().split()
+        for word in tempText:
+            word = self.manualMap.setdefault(word, word)
+            if word not in self.articles:
+                outText.append(word)
+            else:
+                pass
+        for wordId, word in enumerate(outText):
+            if word in self.contractions:
+                outText[wordId] = self.contractions[word]
+        outText = " ".join(outText)
+        return outText
+
+    def setAccuracy(self, accQA, accQuesType, accAnsType):
+        self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
+        self.accuracy["perQuestionType"] = {
+            quesType: round(
+                100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
+                self.n,
+            )
+            for quesType in accQuesType
+        }
+        self.accuracy["perAnswerType"] = {
+            ansType: round(
+                100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
+            )
+            for ansType in accAnsType
+        }
+
+    def setEvalQA(self, quesId, acc):
+        self.evalQA[quesId] = round(100 * acc, self.n)
+
+    def setEvalQuesType(self, quesId, quesType, acc):
+        if quesType not in self.evalQuesType:
+            self.evalQuesType[quesType] = {}
+        self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
+
+    def setEvalAnsType(self, quesId, ansType, acc):
+        if ansType not in self.evalAnsType:
+            self.evalAnsType[ansType] = {}
+        self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
+
+    def updateProgress(self, progress):
+        barLength = 20
+        status = ""
+        if isinstance(progress, int):
+            progress = float(progress)
+        if not isinstance(progress, float):
+            progress = 0
+            status = "error: progress var must be float\r\n"
+        if progress < 0:
+            progress = 0
+            status = "Halt...\r\n"
+        if progress >= 1:
+            progress = 1
+            status = "Done...\r\n"
+        block = int(round(barLength * progress))
+        text = "\rFinshed Percent: [{0}] {1}% {2}".format(
+            "#" * block + "-" * (barLength - block), int(progress * 100), status
+        )
+        sys.stdout.write(text)
+        sys.stdout.flush()
+
+
+def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_path, vqa_dataset):
+    """Compute the VQA accuracy metric.
+
+    Args:
+        predictions (List): list of predictions
+        ground_truth (List[List]): list of all possible ground truth answers
+
+    Returns:
+        float: VQA accuracy
+    """
+    # coding: utf-8
+    # dataDir = data_dir
+
+    # set up file names and paths
+    # versionType = 'v2_'  # this should be '' when using VQA v2.0 dataset
+    # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
+    # taskType = 'OpenEnded'
+    # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
+    # dataType = 'mscoco'
+    # dataSubType = 'train2014'
+    # annFile = '%s/%s%s_%s_annotations.json' % (
+    # dataDir, versionType, dataType, dataSubType)
+    # quesFile = '%s/%s%s_%s_%s_questions.json' % (
+    # dataDir, versionType, taskType, dataType, dataSubType)
+    # imgDir = '%s/%s/%s/' % (dataDir, dataType, dataSubType)
+    # resultType = res_file_name
+    # fileTypes = ['results', 'accuracy',
+    #              'evalQA', 'evalQuesType', 'evalAnsType']
+
+    # An example result json file has been provided in './Results' folder.
+
+    # [resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/%s%s_%s_%s_%s_%s.json' % (dataDir, versionType, taskType, dataType, dataSubType,
+    # resultType, fileType) for fileType in fileTypes]
+
+    # create vqa object and vqaRes object
+    vqa = VQA(annotation_json_path, question_json_path)
+    vqaRes = vqa.loadRes(result_json_path, question_json_path)
+
+    # create vqaEval object by taking vqa and vqaRes
+    # n is precision of accuracy (number of places after decimal), default is 2
+    vqaEval = VQAEval(vqa, vqaRes, n=2)
+
+    # evaluate results
+    """
+    If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
+    By default it uses all the question ids in annotation file
+    """
+    vqaEval.evaluate()
+
+    return vqaEval.accuracy["overall"]
+
+
+def postprocess_vqa_generation(predictions):
+    return re.split("Question|Answer", predictions, 1)[0]
+
+
+def compute_gqa_accuracy(results):
+    acc = []
+    vqa_tool = VQAEval()
+
+    for res in results:
+        gt_ans = res["answers"]
+        pred = res["answer"]
+        pred = vqa_tool.processPunctuation(pred)
+        pred = vqa_tool.processDigitArticle(pred)
+        vqa_acc = 1 if pred == gt_ans else 0
+        acc.append(vqa_acc)
+    accuracy = sum(acc) / len(acc)
+    return accuracy
diff --git a/multimodal/open_flamingo/src/__init__.py b/multimodal/open_flamingo/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/multimodal/open_flamingo/src/attention.py b/multimodal/open_flamingo/src/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..90b9d286f1d1bf1768a085b265d3db39c783eced
--- /dev/null
+++ b/multimodal/open_flamingo/src/attention.py
@@ -0,0 +1,45 @@
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import init
+
+
+
+class SEAttention(nn.Module):
+
+    def __init__(self, channel=512,reduction=16):
+        super().__init__()
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.GELU(),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.GELU(),
+            nn.Linear(channel, 1, bias=False),
+            nn.Sigmoid()
+        )
+
+
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias, 0)
+            elif isinstance(m, nn.Linear):
+                init.normal_(m.weight, std=0.001)
+                if m.bias is not None:
+                    init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.fc(x)
+        return x
+
+
+if __name__ == '__main__':
+    input=torch.randn(50,512,7,7)
+    se = SEAttention(channel=512,reduction=8)
+    output=se(input)
+    print(output.shape)
diff --git a/multimodal/open_flamingo/src/factory.py b/multimodal/open_flamingo/src/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..925b09ed463c15bc895117dbdc1344586274846c
--- /dev/null
+++ b/multimodal/open_flamingo/src/factory.py
@@ -0,0 +1,269 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import open_clip
+import torch
+
+from .flamingo import Flamingo
+from .flamingo_lm import FlamingoLMMixin
+from .utils import extend_instance
+import logging
+import random
+import time
+
+def create_model_and_transforms(
+    clip_vision_encoder_path: str,
+    clip_vision_encoder_pretrained: str,
+    lang_encoder_path: str,
+    tokenizer_path: str,
+    use_local_files: bool = False,
+    decoder_layers_attr_name: str = None,
+    location_token_num: int = 1000,
+    checkpoint_activations: bool = False,
+    freeze_vision_encoder: bool = False,
+    lora: bool = False,
+    lora_r: int = 16,
+    fix_ffn: bool = False,
+    add_visual_token: bool = False,
+    add_box: bool = False,
+    add_pe: bool = False,
+    add_relation: bool = False,
+    use_format_v2: bool = False,
+    use_sam: str = None,
+    enhance_data: bool = False,
+    roi_align: bool = False,
+    roi_output_size: int = 4,
+    apply_mask: bool = False,
+    **flamingo_kwargs,
+):
+    """
+    Initialize a Flamingo model from a pretrained vision encoder and language encoder.
+    Appends special tokens to the tokenizer and freezes backbones.
+
+    Args:
+        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
+        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
+        lang_encoder_path (str): path to pretrained language encoder
+        tokenizer_path (str): path to pretrained tokenizer
+        cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
+        use_local_files (bool, optional): whether to use local files. Defaults to False.
+        decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
+    Returns:
+        Flamingo: Flamingo model from pretrained vision and language encoders
+        Image processor: Pipeline to preprocess input images
+        Tokenizer: A tokenizer for the language model
+    """
+    if use_sam is None:
+        no_success = True
+        while no_success:
+            try:
+                vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
+                    clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
+                )
+                no_success = False
+            except:
+                logging.info("retry creating vision_encoder")
+                time.sleep(random.random() * 5)
+
+        # set the vision encoder to output the visual features
+        vision_encoder.visual.output_tokens = True
+        # delete text encoder part
+        del vision_encoder.transformer
+        del vision_encoder.text_projection
+        del vision_encoder.token_embedding
+        del vision_encoder.ln_final
+        del vision_encoder.positional_embedding
+        del vision_encoder.logit_scale
+        vision_encoder.visual.proj = None
+        vision_encoder.visual.ln_post = torch.nn.Identity()
+    else:
+        from segment_anything import SamPredictor, sam_model_registry
+        assert use_sam == "vit_l"
+        sam = sam_model_registry[use_sam](checkpoint="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/checkpoint/sam_vit_l_0b3195_256x256.pth")
+        del sam.prompt_encoder
+        del sam.mask_decoder
+        sam.image_encoder.neck = torch.nn.Identity()
+        vision_encoder = sam.image_encoder
+        from open_clip.transform import image_transform
+        image_processor = image_transform(
+            256,
+            is_train=False,
+            mean=(0.48145466, 0.4578275, 0.40821073),
+            std=(0.26862954, 0.26130258, 0.27577711),
+        )
+
+    text_tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_path, local_files_only=use_local_files
+    )
+    # add Flamingo special tokens to the tokenizer
+    additional_special_tokens = ["<|#image#|>", "<|#endofimage#|>"]
+    if add_visual_token:
+        additional_special_tokens += ["<|#visual#|>", "<|#object#|>"]
+    if add_box:
+        additional_special_tokens += ["<|#box#|>", "<|#endofobject#|>", "<|#attr#|>", "<|#endofattr#|>"]
+    if use_format_v2:
+        additional_special_tokens += ["<|#previsual#|>", "<|#prebox#|>"]
+    if enhance_data:
+        additional_special_tokens += ["<|#NOTHING#|>"]
+    text_tokenizer.add_special_tokens(
+        {"additional_special_tokens": additional_special_tokens}
+    )
+    if text_tokenizer.pad_token is None:
+        # Issue: GPT models don't have a pad token, which we use to
+        # modify labels for the loss.
+        text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
+
+    lang_encoder = AutoModelForCausalLM.from_pretrained(
+        lang_encoder_path, local_files_only=use_local_files
+    )
+    extend_instance(lang_encoder, FlamingoLMMixin)
+
+    if decoder_layers_attr_name is None:
+        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
+    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
+    lang_encoder.resize_token_embeddings(len(text_tokenizer))
+    lang_encoder_name = lang_encoder.__class__.__name__.lower()
+    if checkpoint_activations:
+        from fairscale.nn.checkpoint import checkpoint_wrapper
+        if use_sam is None:
+            for i in range(len(vision_encoder.visual.transformer.resblocks)):
+                vision_encoder.visual.transformer.resblocks[i] = checkpoint_wrapper(
+                    vision_encoder.visual.transformer.resblocks[i],
+                    offload_to_cpu=False,
+                )
+        else:
+            for i in range(len(vision_encoder.blocks)):
+                vision_encoder.blocks[i] = checkpoint_wrapper(
+                    vision_encoder.blocks[i],
+                    offload_to_cpu=False,
+                )
+        if "opt" in lang_encoder_name:
+            for i in range(len(lang_encoder.model.decoder.layers)):
+                lang_encoder.model.decoder.layers[i] = checkpoint_wrapper(
+                    lang_encoder.model.decoder.layers[i],
+                    offload_to_cpu=False,
+                )
+        elif "codegen" in lang_encoder_name:
+            for i in range(len(lang_encoder.transformer.h)):
+                lang_encoder.transformer.h[i] = checkpoint_wrapper(
+                    lang_encoder.transformer.h[i],
+                    offload_to_cpu=False,
+                )
+        elif "llama" in lang_encoder_name:
+            for i in range(len(lang_encoder.model.layers)):
+                lang_encoder.model.layers[i] = checkpoint_wrapper(
+                    lang_encoder.model.layers[i],
+                    offload_to_cpu=False,
+                )
+        elif "gptneo" in lang_encoder_name:
+            for i in range(len(lang_encoder.gpt_neox.layers)):
+                lang_encoder.gpt_neox.layers[i] = checkpoint_wrapper(
+                    lang_encoder.gpt_neox.layers[i],
+                    offload_to_cpu=False,
+                )
+        else:
+            raise ValueError(f"unknown model {lang_encoder_name}")
+    if use_sam is None:
+        vis_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"]
+        image_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["image_size"]
+        patch_size = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["patch_size"]
+    else:
+        # SAM config
+        vis_dim = 1024
+        image_size = 256
+        patch_size = 16
+    assert image_size % patch_size == 0
+    vis_embed_size = (image_size // patch_size) ** 2
+
+    if lora:
+        from peft import LoraConfig, TaskType
+        from peft import get_peft_model
+        if "codegen" in lang_encoder_name:
+            lang_target_modules = ["qkv_proj", "out_proj", "fc_in", "fc_out"]
+        elif "opt" in lang_encoder_name:
+            lang_target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
+        elif "llama" in lang_encoder_name:
+            lang_target_modules = ["k_proj", "v_proj", "q_proj", "o_proj", "gate_proj", "down_proj", "up_proj"]
+        else:
+            raise NotImplementedError
+        lang_peft_config = LoraConfig(
+            task_type="CAUSAL_LM",
+            r=16, lora_alpha=16,
+            target_modules=lang_target_modules,
+            lora_dropout=0.05, bias="none",
+        )
+        lang_encoder = get_peft_model(lang_encoder, lang_peft_config)
+        lang_encoder.print_trainable_parameters()
+
+    if fix_ffn:
+        if "opt" in lang_encoder_name:
+            for i in range(len(lang_encoder.model.decoder.layers)):
+                lang_encoder.model.decoder.layers[i].requires_grad_(False)
+                lang_encoder.model.decoder.layers[i].self_attn.requires_grad_(True)
+        else:
+            raise NotImplementedError
+
+    lang_dim = int(lang_encoder.config.hidden_size) if not lora else int(lang_encoder.base_model.model.config.hidden_size)
+    if hasattr(lang_encoder.config, "word_embed_proj_dim"):
+        hidden_state_dim = lang_encoder.config.word_embed_proj_dim
+    else:
+        hidden_state_dim = lang_encoder.config.hidden_size
+    model = Flamingo(
+        vision_encoder=vision_encoder,
+        lang_encoder=lang_encoder,
+        eoc_token_id=text_tokenizer.encode(text_tokenizer.eos_token)[-1],
+        media_token_id=text_tokenizer.encode("<|#image#|>")[-1],
+        image_end_token_id=text_tokenizer.encode("<|#endofimage#|>")[-1],
+        visual_token_id=text_tokenizer.encode("<|#visual#|>")[-1] if add_visual_token else None,
+        previsual_token_id=text_tokenizer.encode("<|#previsual#|>")[-1] if add_visual_token else None,
+        box_token_id=text_tokenizer.encode("<|#box#|>")[-1] if add_box else None,
+        prebox_token_id=text_tokenizer.encode("<|#prebox#|>")[-1] if add_box else None,
+        nothing_token_id=text_tokenizer.encode("<|#NOTHING#|>")[-1] if enhance_data else None,
+        endofobject_token_id=text_tokenizer.encode("<|#endofobject#|>")[-1],
+        vis_dim=vis_dim,
+        vis_embed_size=vis_embed_size,
+        lang_dim=lang_dim,
+        image_size=image_size,
+        patch_size=patch_size,
+        hidden_state_dim=hidden_state_dim,
+        add_visual_token=add_visual_token,
+        add_pe=add_pe,
+        add_relation=add_relation,
+        use_format_v2=use_format_v2,
+        roi_align=roi_align,
+        roi_output_size=roi_output_size,
+        apply_mask=apply_mask,
+        **flamingo_kwargs,
+    )
+
+    if freeze_vision_encoder:
+        print("freeze vision encoder")
+        model.vision_encoder.requires_grad_(False)
+
+    print(
+        f"Flamingo model initialized with {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters"
+    )
+
+    return model, image_processor, text_tokenizer, vis_embed_size
+
+
+def _infer_decoder_layers_attr_name(model):
+    for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
+        if k.lower() in model.__class__.__name__.lower():
+            return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
+
+    raise ValueError(
+        f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
+    )
+
+
+__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
+    "opt": "model.decoder.layers",
+    # "gptneo": "transformer.h",
+    "gptj": "transformer.h",
+    "gpt-j": "transformer.h",
+    "pythia": "gpt_neox.layers",
+    "gptneox": "gpt_neox.layers",
+    "llama": "model.layers",
+    "llamaforcausallm": "model.layers",
+    "gpt2": "transformer.h",
+    "codegen": "transformer.h",
+}
diff --git a/multimodal/open_flamingo/src/flamingo.py b/multimodal/open_flamingo/src/flamingo.py
new file mode 100644
index 0000000000000000000000000000000000000000..da54f8f02a046fad7dfcfe32fb59092b24d2f9da
--- /dev/null
+++ b/multimodal/open_flamingo/src/flamingo.py
@@ -0,0 +1,637 @@
+import torch
+import torchvision
+from einops import rearrange
+from torch import nn
+from yolox.models.yolo_head import YOLOXHead
+from yolox.utils.boxes import xyxy2cxcywh, cxcywh2xyxy
+from yolox.utils.demo_utils import nms
+# import matplotlib.pyplot as plt
+# import seaborn as sns
+import numpy as np
+import logging
+from open_flamingo.src.gcn import GCN
+from transformers import LogitsProcessorList
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s %(message)s',
+    datefmt='%m/%d %I:%M:%S',
+)
+
+
+# class PositionEncodingModule(nn.Module):
+#     def __init__(self, dim, pos_dim=128):
+#         super().__init__()
+#         self.encode = nn.Sequential(
+#             nn.Linear(5, pos_dim // 2),
+#             nn.BatchNorm1d(pos_dim // 2),
+#             nn.GELU(),
+#             nn.Linear(pos_dim // 2, pos_dim),
+#             nn.BatchNorm1d(pos_dim),
+#             nn.GELU(),
+#         )
+#         self.merge = nn.Sequential(
+#             nn.Linear(dim + pos_dim, dim),
+#             nn.BatchNorm1d(dim),
+#             nn.GELU(),
+#         )
+
+#     def forward(self, x, box):
+#         box = self.encode(box)
+#         x = torch.cat([x, box], dim=-1)
+#         x = self.merge(x)
+#         return x
+
+
+# class PositionEncodingModule(nn.Module):
+#     def __init__(self, dim):
+#         super().__init__()
+#         self.encode = nn.Sequential(
+#             nn.Linear(5, dim),
+#             nn.GELU(),
+#         )
+
+#     def forward(self, x, box):
+#         box = self.encode(box)
+#         x = x + box
+#         return x
+
+
+# class PositionEncodingModule2(nn.Module):
+#     def __init__(self, dim):
+#         super().__init__()
+#         self.encode = nn.Sequential(
+#             nn.Linear(5 + dim, dim),
+#             nn.ELU(),
+#         )
+
+#     def forward(self, x, box):
+#         x = torch.cat([x, box], dim=-1)
+#         x = self.encode(x)
+#         return x
+
+
+# class RelationHead(nn.Module):
+#     def __init__(self, dim):
+#         super().__init__()
+#         self.encode = nn.Sequential(
+#             nn.LayerNorm(dim),
+#             nn.Linear(dim, 128),
+#             nn.ELU(),
+#         )
+#         self.classifier = nn.Linear(256, 51)
+
+#     def forward(self, x1, x2):
+#         x1 = self.encode(x1)
+#         x2 = self.encode(x2)
+#         x = torch.cat([x1, x2], dim=-1)
+#         x = self.classifier(x)
+#         return x
+
+
+class Flamingo(nn.Module):
+    def __init__(
+        self,
+        vision_encoder: nn.Module,
+        lang_encoder: nn.Module,
+        eoc_token_id: int,
+        media_token_id: int,
+        image_end_token_id: int,
+        visual_token_id: int,
+        previsual_token_id: int,
+        box_token_id: int,
+        prebox_token_id: int,
+        nothing_token_id: int,
+        endofobject_token_id: int,
+        vis_dim: int,
+        vis_embed_size: int,
+        lang_dim: int,
+        hidden_state_dim: int,
+        image_size: int,
+        patch_size: int,
+        use_media_placement_augmentation: bool = False,
+        add_visual_token: bool = False,
+        add_pe: bool = False,
+        add_relation: bool = False,
+        use_format_v2: bool = False,
+        roi_align: bool = False,
+        roi_output_size: int = 4,
+        apply_mask: bool = False,
+    ):
+        """
+        Args:
+            vision_encoder (nn.Module): HF CLIPModel
+            lang_encoder (nn.Module): HF causal language model
+            eoc_token_id (int): Token id for eos token
+            media_token_id (int): Token id for <|#image#|>
+            vis_dim (int): Dimension of the visual features.
+                Visual features are projected to match this shape along the last dimension.
+            cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
+            use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False.
+        """
+        super().__init__()
+        self.image_end_token_id = image_end_token_id
+        self.eoc_token_id = eoc_token_id
+        self.media_token_id = media_token_id
+        self.use_media_placement_augmentation = use_media_placement_augmentation
+        self.vis_dim = vis_dim
+        self.lang_dim = lang_dim
+        # inner_dim = self.lang_dim * 4
+        # self.vis_proj = nn.Sequential(
+        #     nn.LayerNorm(self.vis_dim),
+        #     nn.Linear(self.vis_dim, inner_dim, bias=False),
+        #     nn.GELU(),
+        #     nn.Linear(inner_dim, self.lang_dim, bias=False),
+        # )
+        self.vis_proj = nn.Linear(self.vis_dim, self.lang_dim)
+        self.vision_encoder = vision_encoder
+        self.num_positions = vis_embed_size
+        self.lang_encoder = lang_encoder
+        self.lang_encoder.init_flamingo(
+            media_token_id=media_token_id,
+            use_media_placement_augmentation=self.use_media_placement_augmentation,
+        )
+        first_layer = self.lang_encoder._get_decoder_layers()[0]
+        first_layer.add_visual_token = add_visual_token
+        first_layer.visual_token_id = visual_token_id
+        first_layer.media_token_id = media_token_id
+        first_layer.box_token_id = box_token_id
+        # first_layer.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
+        # assert not (add_pe and add_relation)
+        # self.pos_enc = PositionEncodingModule(self.lang_dim) if add_pe else None
+        # first_layer.pos_enc = self.pos_enc
+        self.box_token_id = box_token_id
+        self.prebox_token_id = prebox_token_id
+        self.media_token_id = media_token_id
+        self.visual_token_id = visual_token_id
+        self.previsual_token_id = previsual_token_id
+        self.hidden_state_dim = hidden_state_dim
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.patch_num = self.image_size // self.patch_size
+        self.detection_head = YOLOXHead(
+            num_classes=1,
+            strides=[patch_size],
+            in_channels=[self.hidden_state_dim + self.lang_dim],
+        )
+        self.use_format_v2 = use_format_v2
+        self.nothing_token_id = nothing_token_id
+        self.roi_align = roi_align
+        self.roi_output_size = roi_output_size if roi_align else None
+        self.apply_mask = apply_mask
+        self.endofobject_token_id = endofobject_token_id
+
+
+    def _get_detection_batch(
+        self,
+        visual_token_id,
+        previsual_token_id,
+        input_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        added_bbox_list,
+        box_num = 100,
+    ):
+        select_mask = torch.logical_or(input_ids == visual_token_id, input_ids == previsual_token_id)
+        visual_token_position = select_mask.nonzero()
+        visual_token_hidden_states = hidden_states[select_mask]
+        prev_batch_idx = -1
+        media_idx = []
+        cnt = 0
+        assert len(visual_token_hidden_states) == len(visual_token_position)
+        if len(added_bbox_list) != len(visual_token_position):
+            msg = f"ERROR: {len(added_bbox_list)}:{len(visual_token_position)}\n{added_bbox_list}\n{visual_token_position}"
+            logging.info(msg)
+            alpha = 0.0
+        else:
+            alpha = 1.0
+        visual_batches = []
+        previsual_batches = []
+        for (batch_idx, idx), visual_token_hidden_state, bbox in zip(
+            visual_token_position, visual_token_hidden_states, added_bbox_list,
+        ):
+            # ! VERY IMPORTANT BUG !
+            bbox = bbox.clone()
+            # ! VERY IMPORTANT BUG !
+            batch_idx = batch_idx.item()
+            idx = idx.item()
+            if batch_idx != prev_batch_idx:
+                prev_batch_idx = batch_idx
+                this_input_ids = input_ids[batch_idx]
+                cnt += len(media_idx)
+                media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
+            for i in range(len(media_idx)):
+                if i == len(media_idx) - 1 or idx > media_idx[i] and idx < media_idx[i+1]:
+                    break
+            image_index = cnt + i
+            size = int(self.image_embedding[image_index].shape[0] ** 0.5)
+            image_embedding = self.image_embedding[image_index]
+            # inplace xyxy2cxcywh
+            # print(bbox)
+            # TODO: CHECK self.image_size. Is it 224?
+            bbox = xyxy2cxcywh(bbox) * self.image_size
+            # print(bbox)
+            concat_image_visual_embedding = torch.cat([image_embedding, visual_token_hidden_state.unsqueeze(0).repeat(image_embedding.shape[0], 1)], dim=-1).reshape(size, size, -1)
+            label = torch.cat([torch.zeros(bbox.shape[0], 1, device=bbox.device), bbox], dim=-1)
+            label = torch.cat([label, torch.zeros(box_num - label.shape[0], label.shape[1], device=label.device)], dim=0)
+            if input_ids[batch_idx, idx] == previsual_token_id:
+                previsual_batches.append([concat_image_visual_embedding, label])
+            elif input_ids[batch_idx, idx] == visual_token_id:
+                visual_batches.append([concat_image_visual_embedding, label])
+            else:
+                logging.info(f"WARNING... NOT visual nor previsual. it is {input_ids[batch_idx, idx]}")
+        return visual_batches, previsual_batches, alpha, alpha
+
+    def get_detection_losses(
+        self,
+        input_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        added_bbox_list,
+        box_num = 100,
+    ):
+        visual_token_batches, previsual_token_batches, alpha1, alpha2 = self._get_detection_batch(
+            visual_token_id=self.visual_token_id,
+            previsual_token_id=self.previsual_token_id,
+            input_ids=input_ids,
+            hidden_states=hidden_states,
+            added_bbox_list=added_bbox_list,
+            box_num=box_num,
+        )
+        loss_dict = []
+        for batches, alpha in zip([visual_token_batches, previsual_token_batches], [alpha1, alpha2]):
+            # x: [B, C, H, W]
+            if len(batches) != 0:
+                x = torch.cat([batch[0].unsqueeze(0) for batch in batches], dim=0).permute(0,3,1,2)
+                labels = torch.cat([batch[1].unsqueeze(0) for batch in batches], dim=0)
+            else:
+                x = None
+                labels = None
+            if x is not None:
+                losses = self.detection_head(xin=[x], labels=labels)
+                loss, loss_iou, loss_obj, loss_cls, loss_l1, _ = losses
+            else:
+                loss = torch.tensor(0.0).cuda()
+                loss_iou = loss
+                loss_obj = loss
+                loss_cls = loss
+                loss_l1 = loss
+
+            loss_dict.append(dict(
+                loss=loss * alpha,
+                loss_iou=loss_iou * alpha,
+                loss_obj=loss_obj * alpha,
+                loss_cls=loss_cls * alpha,
+                loss_l1=loss_l1 * alpha,
+            ))
+        ret_loss = {}
+        for key in loss_dict[0].keys():
+            ret_loss[key] = 0.0
+            for d in loss_dict:
+                ret_loss[key] += d[key]
+        return ret_loss, loss_dict
+
+    def get_detection_result(
+        self,
+        input_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        nms_thr: float = 0.45,
+        score_thr: float = 0.01,
+        debug_id: int = 0,
+        debug_mode: bool = False,
+    ):
+        assert len(input_ids) == 1, "only batch size = 1 is supported yet"
+        # assert len(self.image_embedding) == 1, "only one image is supported yet" 
+        # assert (input_ids[..., -1] == self.visual_token_id).all(), "the last token should be visual token"
+        visual_token_hidden_state = hidden_states[..., -1, :]
+        boxes_list = []
+        scores_list = []
+        for image_embedding in self.image_embedding:
+            size = int(image_embedding.shape[0] ** 0.5)
+            x = torch.cat([image_embedding, visual_token_hidden_state.repeat(image_embedding.shape[0], 1)], dim=-1).reshape(size, size, -1).unsqueeze(0).permute(0,3,1,2)
+            with torch.no_grad():
+                outputs = self.detection_head(xin=[x], labels=None)
+            boxes = outputs[0,:,:4].cpu().numpy()
+            scores = outputs[0,:,4].cpu().numpy()
+            scores_mask = scores > score_thr
+            boxes = boxes[scores_mask]
+            boxes = cxcywh2xyxy(boxes)
+            scores = scores[scores_mask]
+            keep = nms(boxes, scores, nms_thr=nms_thr)
+            boxes = boxes[keep]
+            scores = scores[keep]
+            if debug_mode:
+                obj_heatmap = outputs[0,:, -2].reshape(size, size).cpu().numpy()
+                import matplotlib.pyplot as plt
+                import seaborn as sns
+                plt.figure()
+                sns_plot = sns.heatmap(obj_heatmap)
+                plt.savefig(f"heatmap_{debug_id}.jpg")
+                debug_id += 1
+            boxes_list.append(boxes)
+            scores_list.append(scores)
+        if len(boxes_list) == 1:
+            boxes_list = boxes_list[0]
+            scores_list = scores_list[0]
+        return boxes_list, scores_list
+
+    def _condition_attention(self, loc_list = None):
+        for i in range(len(self.lang_encoder.gpt_neox.layers)):
+            self.lang_encoder.gpt_neox.layers[i].decoder_layer.attention.loc_list = loc_list
+
+    def forward(
+        self,
+        vision_x: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        labels: torch.Tensor = None,
+        use_cached_vision_x: bool = False,
+        clear_conditioned_layers: bool = True,
+        past_key_values=None,
+        use_cache: bool = False,
+        image_nums=None,
+        image_start_index_list=None,
+        added_bbox_list=None,
+        add_box: bool = False,
+        relations=None,
+        debug_mode: bool = False,
+    ):
+        """
+        Forward pass of Flamingo.
+
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W) with F=1
+            lang_x (torch.Tensor): Language input ids
+                shape (B, T_txt)
+            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+            labels (torch.Tensor, optional): Labels. Defaults to None.
+            clear_conditioned_layers: if True, clear the conditioned layers
+                once the foward pass is completed. Set this to false if the
+                same set of images will be reused in another subsequent
+                forward pass.
+            past_key_values: pre-computed values to pass to language model.
+                See past_key_values documentation in Hugging Face
+                CausalLM models.
+            use_cache: whether to use cached key values. See use_cache
+                documentation in Hugging Face CausalLM models.
+        """
+        self.valid = True
+        self.lang_encoder.loc_list = None
+        if use_cached_vision_x:
+            # Case: use cached; vision_x should be cached and other
+            # vision-related inputs should not be provided.
+            assert (
+                vision_x is None
+            ), "Expect vision_x to be None when use_cached_vision_x is True."
+            assert self.lang_encoder.is_conditioned()
+        else:
+            # Case: do not use caching (i.e. this is a standard forward pass);
+            self._encode_vision_x(
+                vision_x=vision_x,
+                image_nums=image_nums,
+                image_start_index_list=image_start_index_list,
+                added_bbox_list=added_bbox_list if add_box else None,
+                input_ids=lang_x,
+                relations=relations,
+            )
+        if self.apply_mask:
+            if self.roi_align:
+                attend_length = 1 + self.roi_output_size ** 2
+            else:
+                attend_length = 2
+            prebox_loc = (lang_x == self.prebox_token_id).nonzero()
+            loc_list = []
+            for (x, y) in prebox_loc:
+                x = x.item()
+                y = y.item()
+                for yy in range(y+1, lang_x.shape[1]):
+                    if lang_x[x, yy] == self.endofobject_token_id:
+                        # [batch_idx, [previsual:prebox], [object:endofobject-1]]
+                        loc_list.append([x, [y-attend_length+1, y], [y+1, yy-1]])
+            self._condition_attention(loc_list=loc_list)
+        else:
+            self._condition_attention(None)
+
+        output = self.lang_encoder(
+            input_ids=lang_x,
+            attention_mask=attention_mask,
+            labels=labels,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_hidden_states=True,
+        )
+        if vision_x is None:
+            output['loss'][0] += 0.0 * self.vis_proj(self.vision_encoder.visual(torch.randn(1, 3, 224, 224, device=lang_x.device, dtype=output['loss'].dtype))[1]).mean()
+        
+        hidden_states = output["hidden_states"][-1]
+        if self.training and added_bbox_list is not None:
+            detection_losses, loss_dict = self.get_detection_losses(
+                input_ids=lang_x,
+                hidden_states=hidden_states,
+                added_bbox_list=added_bbox_list,
+            )
+            output["detection_losses"] = detection_losses
+            output["loss_dict"] = loss_dict
+        elif labels is None:
+            boxes, scores = self.get_detection_result(
+                input_ids=lang_x,
+                hidden_states=hidden_states,
+                debug_id=self.debug_id if hasattr(self, "debug_id") else None,
+                debug_mode=debug_mode,
+            )
+            output["boxes"] = boxes
+            output["scores"] = scores
+
+        if clear_conditioned_layers:
+            self.lang_encoder.clear_conditioned_layers()
+        self._condition_attention(None)
+        return output
+
+    def generate(
+        self,
+        vision_x: torch.Tensor,
+        lang_x: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        added_bbox_list=None,
+        num_beams=1,
+        max_new_tokens=None,
+        temperature=1.0,
+        top_k=0,
+        top_p=1.0,
+        no_repeat_ngram_size=0,
+        prefix_allowed_tokens_fn=None,
+        length_penalty=1.0,
+        num_return_sequences=1,
+        do_sample=False,
+        early_stopping=False,
+        bad_words_ids=None,
+        force_words_ids=None,
+        image_start_index_list=None,
+        image_nums=None,
+        min_length=None,
+        return_dict_in_generate=False,
+        output_hidden_states=False,
+        output_scores=False,
+        logits_processor_list=None,
+        eos_token_id=None,
+    ):
+        """
+        Generate text conditioned on vision and language inputs.
+
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                images in the same chunk are collated along T_img, and frames are collated along F
+                currently only F=1 is supported (single-frame videos)
+            lang_x (torch.Tensor): Language input
+                shape (B, T_txt)
+            max_length (int, optional): Maximum length of the output. Defaults to None.
+            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
+            num_beams (int, optional): Number of beams. Defaults to 1.
+            max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
+            temperature (float, optional): Temperature. Defaults to 1.0.
+            top_k (int, optional): Top k. Defaults to 0.
+            top_p (float, optional): Top p. Defaults to 1.0.
+            no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
+            length_penalty (float, optional): Length penalty. Defaults to 1.0.
+            num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
+            do_sample (bool, optional): Do sample. Defaults to False.
+            early_stopping (bool, optional): Early stopping. Defaults to False.
+        Returns:
+            torch.Tensor: lang_x with generated tokens appended to it
+        """
+        if num_beams > 1:
+            vision_x = vision_x.repeat_interleave(num_beams, dim=0)
+            image_start_index_list = torch.tensor(image_start_index_list).repeat_interleave(num_beams, dim=0).tolist()
+            image_nums = torch.tensor(image_nums).repeat_interleave(num_beams, dim=0).tolist()
+            if added_bbox_list is not None and len(added_bbox_list) != 0:
+                added_bbox_list = added_bbox_list * num_beams
+
+        self._encode_vision_x(vision_x=vision_x, image_nums=image_nums, image_start_index_list=image_start_index_list, num_beams=num_beams, added_bbox_list=added_bbox_list, input_ids=lang_x.repeat_interleave(num_beams, dim=0))
+
+        if logits_processor_list is not None:
+            assert isinstance(logits_processor_list, list)
+            logits_processor_list = LogitsProcessorList(logits_processor_list)
+        output = self.lang_encoder.generate(
+            input_ids=lang_x,
+            attention_mask=attention_mask,
+            eos_token_id=(self.eoc_token_id) if eos_token_id is None else eos_token_id,
+            num_beams=num_beams,
+            max_new_tokens=max_new_tokens,
+            min_length=min_length,
+            length_penalty=length_penalty,
+            logits_processor=logits_processor_list,
+            return_dict_in_generate=return_dict_in_generate,
+            output_scores=output_scores,
+        )
+        self.lang_encoder.clear_conditioned_layers()
+        return output
+
+    def _get_data_list_and_visual_tokens(
+        self,
+        all_box_list,
+        box_token_id,
+        prebox_token_id,
+        input_ids,
+        vision_x,
+        nothing_embedding = None,
+    ):
+        box_locations = (torch.logical_or(input_ids == box_token_id, input_ids == prebox_token_id)).nonzero()
+        prev_batch_idx = -1
+        media_idx = []
+        cnt = 0
+        data_list = []
+        visual_tokens = []
+        if len(all_box_list) != len(box_locations):
+            logging.info(f"WARNING. len(all_box_list) != len(box_locations) {len(all_box_list)} vs {len(box_locations)}")
+            self.valid = False
+        for III, (batch_idx, idx) in enumerate(box_locations):
+            batch_idx = batch_idx.item()
+            idx = idx.item()
+            if batch_idx != prev_batch_idx:
+                prev_batch_idx = batch_idx
+                this_input_ids = input_ids[batch_idx]
+                cnt += len(media_idx)
+                media_idx = (this_input_ids == self.media_token_id).nonzero().reshape(-1).tolist()
+            for i in range(len(media_idx)):
+                if i == len(media_idx) - 1 or idx > media_idx[i] and idx < media_idx[i+1]:
+                    break
+            image_index = cnt + i
+            size = int(vision_x[image_index].shape[0] ** 0.5)
+            image_feature = vision_x[image_index].reshape(size, size, -1)
+            try:
+                raw_xyxy = all_box_list[III]
+            except:
+                logging.info("out of scope for all_box_list")
+                raw_xyxy = all_box_list[-1]
+            region_xyxy = np.array(raw_xyxy) * size
+            x1, y1, x2, y2 = region_xyxy.astype(int).clip(0, size-1).tolist()
+            x2 = max(x1, x2)
+            y2 = max(y1, y2)
+            if x1 + y1 + x2 + y2 == 0.0 and nothing_embedding is not None:
+                visual_token = nothing_embedding
+            else:
+                if self.roi_align:
+                    visual_token = torchvision.ops.roi_align(
+                        image_feature.permute(2, 0, 1).unsqueeze(0),
+                        [torch.tensor(region_xyxy.astype(np.float32)).unsqueeze(0).cuda()],
+                        output_size=self.roi_output_size,
+                        spatial_scale=1.0,
+                    )
+                    visual_token = visual_token.squeeze(0).flatten(1).permute(1, 0)
+                else:
+                    visual_token = image_feature[y1:y2+1, x1:x2+1].reshape(-1, image_feature.shape[-1]).mean(0)
+            box = torch.tensor([0] + raw_xyxy, device=visual_token.device, dtype=visual_token.dtype)
+            data_list.append([visual_token, box, batch_idx, idx, i])
+            visual_tokens.append(visual_token)
+        return data_list, visual_tokens
+
+    def _encode_vision_x(self, vision_x: torch.Tensor, image_nums=None, image_start_index_list=None, added_bbox_list=None, num_beams=None, input_ids=None, relations=None):
+        """
+        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
+        Args:
+            vision_x (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+                Images in the same chunk are collated along T_img, and frames are collated along F
+                Currently only F=1 is supported (single-frame videos)
+
+        rearrange code based on https://github.com/dhansmair/flamingo-mini
+        """
+        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
+        b, T, F = vision_x.shape[:3]
+        assert F == 1, "Only single frame supported"
+
+        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
+        if hasattr(self.vision_encoder, "visual"):
+            vision_x = self.vision_encoder.visual(vision_x)[1]
+        else:
+            vision_x = self.vision_encoder(vision_x).flatten(2).permute(0, 2, 1)
+        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
+
+        # print(vision_x[0,0,0])
+        # # DEBUG HERE
+        # if torch.distributed.get_rank() == 0:
+        #     import pdb; pdb.set_trace()
+        # else:
+        #     torch.distributed.barrier()
+        vision_x = vision_x.mean(2)
+        # vision_x = self.perceiver(vision_x)  # reshapes to (b, T, n, d)
+        # vision_x = self.vis_proj(vision_x) + self.vis_position_embedding(self.vis_position_ids).unsqueeze(0)
+        vision_x = self.vis_proj(vision_x).squeeze(1)
+        self.image_embedding = vision_x
+
+        data_list = None
+        visual_tokens = None
+        if added_bbox_list is not None and input_ids is not None:
+            all_box_list = added_bbox_list[0].tolist()
+            for list in added_bbox_list[1:]:
+                all_box_list.extend(list.tolist())
+            data_list, visual_tokens = self._get_data_list_and_visual_tokens(
+                all_box_list=all_box_list,
+                box_token_id=self.box_token_id,
+                prebox_token_id=self.prebox_token_id,
+                input_ids=input_ids,
+                vision_x=vision_x,
+                nothing_embedding=self.lang_encoder.gpt_neox.embed_in(torch.tensor(self.nothing_token_id).to(self.lang_encoder.gpt_neox.embed_in.weight.device)) if self.nothing_token_id is not None else None,
+            )
+
+        first_layer = self.lang_encoder._get_decoder_layers()[0]
+        first_layer.condition_vis_x(vision_x, image_nums, image_start_index_list, num_beams=num_beams, visual_tokens=visual_tokens, data_list=[[d[2], d[3]] for d in data_list] if data_list is not None else data_list)
diff --git a/multimodal/open_flamingo/src/flamingo_lm.py b/multimodal/open_flamingo/src/flamingo_lm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3c1a5f981f95fa22219c73bdcd288165317d13c
--- /dev/null
+++ b/multimodal/open_flamingo/src/flamingo_lm.py
@@ -0,0 +1,173 @@
+import random
+import torch
+import torch.nn as nn
+import numpy as np
+
+from .helpers import GatedCrossAttentionBlock
+from .utils import getattr_recursive, setattr_recursive
+
+
+class FlamingoLayer(nn.Module):
+    def __init__(self, decoder_layer):
+        super().__init__()
+        self.decoder_layer = decoder_layer
+        self.vis_x = None
+        self.image_nums = None
+        self.image_start_index_list = None
+        self.media_locations = None
+        self.add_visual_token = False
+        self.input_ids = None
+
+    def is_conditioned(self) -> bool:
+        """Check whether the layer is conditioned."""
+        return self.vis_x is not None
+
+    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
+    def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, num_beams=None, visual_tokens=None, data_list=None):
+        self.vis_x = vis_x
+        self.image_nums = image_nums
+        self.image_start_index_list = image_start_index_list
+        self.num_beams = num_beams
+        self.visual_tokens = visual_tokens
+        self.data_list = data_list
+        self.input_ids = None
+
+
+    def condition_media_locations(self, media_locations):
+        self.media_locations = media_locations
+
+    def condition_attend_previous(self, attend_previous):
+        self.attend_previous = attend_previous
+
+    def forward(
+        self,
+        hidden_states, # alignment with hugging face name
+        attention_mask=None,
+        **decoder_layer_kwargs,
+    ):
+        if self.media_locations is None:
+            raise ValueError("media_locations must be conditioned before forward pass")
+
+        if self.vis_x is not None:
+            if self.training:
+                single_length = self.vis_x.shape[-2]
+                image_nums = self.image_nums
+                image_start_index_list = self.image_start_index_list
+                image_nums = [0] + np.cumsum(image_nums).tolist()
+                for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
+                    for index in start_indices:
+                        if image_num_begin < image_num_end:
+                            hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
+                            image_num_begin += 1
+
+                if self.visual_tokens is not None and len(self.visual_tokens) != 0:
+                    for i, (x, y) in enumerate(self.data_list):
+                        if len(self.visual_tokens[i].shape) > 1:
+                            # print(self.visual_tokens[i].shape[0], "embedding")
+                            hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i]
+                        else:
+                            # print(self.visual_tokens[i].shape[0], "embedding")
+                            hidden_states[x, y] = self.visual_tokens[i]
+
+            elif not self.training:
+                if (
+                    ("past_key_value" in decoder_layer_kwargs and decoder_layer_kwargs["past_key_value"] is None) or
+                    ("layer_past" in decoder_layer_kwargs and decoder_layer_kwargs["layer_past"] is None)
+                ):
+                    single_length = self.vis_x.shape[-2]
+                    image_nums = self.image_nums
+                    image_start_index_list = self.image_start_index_list
+                    image_nums = [0] + np.cumsum(image_nums).tolist()
+                    for i, (image_num_begin, image_num_end, start_indices) in enumerate(zip(image_nums[:-1], image_nums[1:], image_start_index_list)):
+                        for index in start_indices:
+                            if image_num_begin < image_num_end:
+                                hidden_states[i, index:index+single_length] = self.vis_x[image_num_begin]
+                                image_num_begin += 1
+                    if self.visual_tokens is not None and len(self.visual_tokens) != 0:
+                        for i, (x, y) in enumerate(self.data_list):
+                            # import pdb; pdb.set_trace()
+                            # print(x, y, self.visual_tokens[i].shape)
+                            if len(self.visual_tokens[i].shape) > 1:
+                                # print(self.visual_tokens[i].shape[0], "embedding")
+                                hidden_states[x, y+1-self.visual_tokens[i].shape[0]:y+1] = self.visual_tokens[i]
+                            else:
+                                # print(self.visual_tokens[i].shape[0], "embedding")
+                                hidden_states[x, y] = self.visual_tokens[i]
+        hidden_states = self.decoder_layer(
+            hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs
+        )
+        return hidden_states
+
+
+class FlamingoLMMixin(nn.Module):
+    """
+    Mixin to add cross-attention layers to a language model.
+    """
+
+    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
+        self.decoder_layers_attr_name = decoder_layers_attr_name
+
+    def _get_decoder_layers(self):
+        return getattr_recursive(self, self.decoder_layers_attr_name)
+
+    def _set_decoder_layers(self, value):
+        setattr_recursive(self, self.decoder_layers_attr_name, value)
+
+    def init_flamingo(
+        self,
+        media_token_id,
+        use_media_placement_augmentation,
+    ):
+        """
+        Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
+        """
+        self._set_decoder_layers(
+            nn.ModuleList(
+                [FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()]
+            )
+        )
+        self.media_token_id = media_token_id
+        self.use_media_placement_augmentation = use_media_placement_augmentation
+        self.initialized_flamingo = True
+
+    def forward(self, *input, **kwargs):
+        """Condition the Flamingo layers on the media locations before forward()"""
+        if not self.initialized_flamingo:
+            raise ValueError(
+                "Flamingo layers are not initialized. Please call `init_flamingo` first."
+            )
+
+        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
+        media_locations = input_ids == self.media_token_id
+        attend_previous = (
+            (random.random() < 0.5) if self.use_media_placement_augmentation else True
+        )
+
+        if (
+            "gpt2" in self.__class__.__name__.lower()
+            or "codegen" in self.__class__.__name__.lower()
+        ):
+            for layer in self.transformer.h:
+                layer.condition_media_locations(media_locations)
+                layer.condition_attend_previous(attend_previous)
+        elif "gptneox" in self.__class__.__name__.lower():
+            for layer in self.gpt_neox.layers:
+                layer.condition_media_locations(media_locations)
+                layer.condition_attend_previous(attend_previous)
+        else:
+            for layer in self.get_decoder().layers:
+                layer.condition_media_locations(media_locations)
+                layer.condition_attend_previous(attend_previous)
+        return super().forward(
+            *input, **kwargs
+        )  # Call the other parent's forward method
+
+    def is_conditioned(self) -> bool:
+        """Check whether all decoder layers are already conditioned."""
+        return all(l.is_conditioned() for l in self._get_decoder_layers())
+
+    def clear_conditioned_layers(self):
+        for layer in self._get_decoder_layers():
+            layer.condition_vis_x(None)
+            layer.condition_media_locations(None)
+            layer.condition_attend_previous(None)
diff --git a/multimodal/open_flamingo/src/gcn.py b/multimodal/open_flamingo/src/gcn.py
new file mode 100644
index 0000000000000000000000000000000000000000..25794bb8b2600b01137cf77b8336a9e8a21e7922
--- /dev/null
+++ b/multimodal/open_flamingo/src/gcn.py
@@ -0,0 +1,137 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+import math
+from torch.autograd import Variable
+from torchvision.ops import box_iou
+
+
+
+class GraphConvolution(nn.Module):
+    """
+    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
+    """
+
+    def __init__(self, in_features, out_features, bias=True, skip=True):
+        super(GraphConvolution, self).__init__()
+        self.skip = skip
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.Tensor(in_features, out_features))
+        if bias:
+            self.bias = Parameter(torch.Tensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.weight.size(1))
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
+
+    def forward(self, input, adj):
+        # TODO make fc more efficient via "pack_padded_sequence"
+        # import ipdb; ipdb.set_trace()
+        support = torch.bmm(input, self.weight.unsqueeze(
+            0).expand(input.shape[0], -1, -1))
+        output = torch.bmm(adj, support)
+        #output = SparseMM(adj)(support)
+        if self.bias is not None:
+            output += self.bias.unsqueeze(0).expand(input.shape[0], -1, -1)
+        if self.skip:
+            output += support
+
+        return output
+
+    def __repr__(self):
+        return self.__class__.__name__ + ' (' \
+            + str(self.in_features) + ' -> ' \
+            + str(self.out_features) + ')'
+
+
+class GCN_sim(nn.Module):
+    def __init__(self, dim_in, dim_hidden, dim_out, dropout, num_layers):
+        super(GCN_sim, self).__init__()
+        assert num_layers >= 1
+        self.fc_k = nn.Linear(dim_in, dim_hidden)
+        self.fc_q = nn.Linear(dim_in, dim_hidden)
+
+        dim_hidden = dim_out if num_layers == 1 else dim_hidden
+        self.gcs = nn.ModuleList([
+            GraphConvolution(dim_in, dim_hidden)
+        ])
+
+        for i in range(num_layers - 1):
+            dim_tmp = dim_out if i == num_layers-2 else dim_hidden
+            self.gcs.append(GraphConvolution(dim_hidden, dim_tmp))
+
+        self.dropout = dropout
+
+    def construct_graph(self, x, length):
+        # TODO make fc more efficient via "pack_padded_sequence"
+        emb_k = self.fc_k(x)
+        emb_q = self.fc_q(x)
+
+        s = torch.bmm(emb_k, emb_q.transpose(1, 2))
+
+        s_mask = s.data.new(*s.size()).fill_(1).bool()  # [B, T1, T2]
+        # Init similarity mask using lengths
+        for i, (l_1, l_2) in enumerate(zip(length, length)):
+            s_mask[i][:l_1, :l_2] = 0
+        s_mask = Variable(s_mask)
+        s.data.masked_fill_(s_mask.data, -float("inf"))
+
+        a_weight = F.softmax(s, dim=2)  # [B, t1, t2]
+        # remove nan from softmax on -inf
+        a_weight.data.masked_fill_(a_weight.data != a_weight.data, 0)
+
+        return a_weight
+
+    def forward(self, x, length):
+        adj_sim = self.construct_graph(x, length)
+
+        for gc in self.gcs:
+            x = F.relu(gc(x, adj_sim))
+            x = F.dropout(x, self.dropout, training=self.training)
+
+        return x
+
+
+class GCN(nn.Module):
+    def __init__(self, dim_in, dim_hidden, dim_out, dropout, mode, skip, num_layers, ST_n_next=None):
+        super(GCN, self).__init__()
+        assert len(mode) != 0
+        self.mode = mode
+        self.skip = skip
+
+        if "GCN_sim" in mode:
+            self.GCN_sim = GCN_sim(
+                dim_in, dim_hidden, dim_out, dropout, num_layers)
+
+    def forward(self, x, length):
+
+        out = []
+        if "GCN_sim" in self.mode:
+            out.append(self.GCN_sim(x, length))
+
+        out = sum(out)
+        if self.skip:
+            out += x
+
+        return out
+
+
+if __name__ == '__main__':
+    model = GCN(512, 128, 512, 0.5, mode=[
+                "GCN_sim"], skip=True, num_layers=3, ST_n_next=3)
+    bs, T, N = 10, 5, 10
+    n_node = T*N
+
+    input = torch.rand(bs, n_node, 512)
+    length = torch.ones((bs))
+    length = length.type(torch.IntTensor)
+    bboxes = torch.rand((bs, 5, 10, 4))
+
+    output = model(input, length)
diff --git a/multimodal/open_flamingo/src/helpers.py b/multimodal/open_flamingo/src/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..41fa2c073e5d3141a08dac7f109bb60a682bfe14
--- /dev/null
+++ b/multimodal/open_flamingo/src/helpers.py
@@ -0,0 +1,263 @@
+"""
+Taken from https://github.com/lucidrains/flamingo-pytorch
+"""
+
+import torch
+from einops import rearrange, repeat
+from einops_exts import rearrange_many
+from torch import einsum, nn
+
+
+def exists(val):
+    return val is not None
+
+
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm_media = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, n1, D)
+            latent (torch.Tensor): latent features
+                shape (b, T, n2, D)
+        """
+        x = self.norm_media(x)
+        latents = self.norm_latents(latents)
+
+        h = self.heads
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
+        q = q * self.scale
+
+        # attention
+        sim = einsum("... i d, ... j d  -> ... i j", q, k)
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        out = einsum("... i j, ... j d -> ... i d", attn, v)
+        out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
+        return self.to_out(out)
+
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth=6,
+        dim_head=64,
+        heads=8,
+        num_latents=64,
+        max_num_media=None,
+        max_num_frames=None,
+        ff_mult=4,
+    ):
+        super().__init__()
+        assert False, "Do not use PerceiverResampler"
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+        self.frame_embs = (
+            nn.Parameter(torch.randn(max_num_frames, dim))
+            if exists(max_num_frames)
+            else None
+        )
+        self.media_time_embs = (
+            nn.Parameter(torch.randn(max_num_media, 1, dim))
+            if exists(max_num_media)
+            else None
+        )
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, x):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, T, F, v, D)
+        Returns:
+            shape (b, T, n, D) where n is self.num_latents
+        """
+        b, T, F, v = x.shape[:4]
+
+        # frame and media time embeddings
+        if exists(self.frame_embs):
+            frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
+            x = x + frame_embs
+        x = rearrange(
+            x, "b T F v d -> b T (F v) d"
+        )  # flatten the frame and spatial dimensions
+        if exists(self.media_time_embs):
+            x = x + self.media_time_embs[:T]
+
+        # blocks
+        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+        return self.norm(latents)
+
+
+# gated cross attention
+
+
+class MaskedCrossAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_visual,
+        dim_head=64,
+        heads=8,
+        only_attend_immediate_media=True,
+    ):
+        super().__init__()
+        self.scale = dim_head**-0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+        # whether for text to only attend to immediate preceding image, or all previous images
+        self.only_attend_immediate_media = only_attend_immediate_media
+
+    def forward(self, x, media, media_locations=None, attend_previous=True):
+        """
+        Args:
+            x (torch.Tensor): text features
+                shape (B, T_txt, D_txt)
+            media (torch.Tensor): image features
+                shape (B, T_img, n, D_img) where n is the dim of the latents
+            media_locations: boolean mask identifying the media tokens in x
+                shape (B, T_txt)
+            attend_previous: bool
+                If false, ignores immediately preceding image and starts attending when following image
+        """
+        assert attend_previous, "text must attend to the image that before it"
+
+        _, T_img, n = media.shape[:3]
+        h = self.heads
+
+        x = self.norm(x)
+
+        q = self.to_q(x)
+        media = rearrange(media, "b t n d -> b (t n) d")
+
+        k, v = self.to_kv(media).chunk(2, dim=-1)
+        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
+
+        q = q * self.scale
+
+        sim = einsum("... i d, ... j d -> ... i j", q, k)
+
+        if exists(media_locations):
+            # at each boolean of True, increment the time counter (relative to media time)
+            text_time = media_locations.cumsum(dim=-1)
+            media_time = torch.arange(T_img, device=x.device) + 1
+
+            if not attend_previous:
+                text_time[~media_locations] += 1
+                # make sure max is still the number of images in the sequence
+                text_time[
+                    text_time
+                    > repeat(
+                        torch.count_nonzero(media_locations, dim=1),
+                        "b -> b i",
+                        i=text_time.shape[1],
+                    )
+                ] = 0
+
+            # text time must equal media time if only attending to most immediate image
+            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
+            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
+
+            text_to_media_mask = mask_op(
+                rearrange(text_time, "b i -> b 1 i 1"),
+                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
+            )
+            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
+
+        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+        attn = sim.softmax(dim=-1)
+
+        if exists(media_locations) and self.only_attend_immediate_media:
+            # any text without a preceding media needs to have attention zeroed out
+            text_without_media_mask = text_time == 0
+            text_without_media_mask = rearrange(
+                text_without_media_mask, "b i -> b 1 i 1"
+            )
+            attn = attn.masked_fill(text_without_media_mask, 0.0)
+
+        out = einsum("... i j, ... j d -> ... i d", attn, v)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        return self.to_out(out)
+
+
+class GatedCrossAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_visual,
+        dim_head=64,
+        heads=8,
+        ff_mult=4,
+        only_attend_immediate_media=True,
+    ):
+        super().__init__()
+        self.attn = MaskedCrossAttention(
+            dim=dim,
+            dim_visual=dim_visual,
+            dim_head=dim_head,
+            heads=heads,
+            only_attend_immediate_media=only_attend_immediate_media,
+        )
+
+    def forward(
+        self,
+        x,
+        media,
+        media_locations=None,
+        attend_previous=True,
+    ):
+        x = self.attn(x, media, media_locations=media_locations, attend_previous=attend_previous) + x
+        return x
diff --git a/multimodal/open_flamingo/src/utils.py b/multimodal/open_flamingo/src/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..815c70016c33ca9133aba60811a4948e31a2df27
--- /dev/null
+++ b/multimodal/open_flamingo/src/utils.py
@@ -0,0 +1,31 @@
+def extend_instance(obj, mixin):
+    """Apply mixins to a class instance after creation"""
+    base_cls = obj.__class__
+    base_cls_name = obj.__class__.__name__
+    obj.__class__ = type(
+        base_cls_name, (mixin, base_cls), {}
+    )  # mixin needs to go first for our forward() logic to work
+
+
+def getattr_recursive(obj, att):
+    """
+    Return nested attribute of obj
+    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
+    """
+    if att == "":
+        return obj
+    i = att.find(".")
+    if i < 0:
+        return getattr(obj, att)
+    else:
+        return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
+
+
+def setattr_recursive(obj, att, val):
+    """
+    Set nested attribute of obj
+    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
+    """
+    if "." in att:
+        obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
+    setattr(obj, att.split(".")[-1], val)
diff --git a/multimodal/open_flamingo/train/__init__.py b/multimodal/open_flamingo/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/multimodal/open_flamingo/train/__init__.py
@@ -0,0 +1 @@
+
diff --git a/multimodal/open_flamingo/train/data2.py b/multimodal/open_flamingo/train/data2.py
new file mode 100644
index 0000000000000000000000000000000000000000..1406df3b3af066fff71860f34708015b9778cc2a
--- /dev/null
+++ b/multimodal/open_flamingo/train/data2.py
@@ -0,0 +1,868 @@
+import functools
+import logging
+import math
+import random
+import sys
+from dataclasses import dataclass
+from multiprocessing import Value
+import time
+import os
+import numpy as np
+import pickle as pkl
+from open_flamingo.train.instruction_template import (
+    VG_RELATION_TEMPLATES,
+    PISC_TEMPLATES,
+)
+
+import torch
+import webdataset as wds
+from PIL import Image
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
+from torch.utils.data.distributed import DistributedSampler
+from webdataset.tariterators import (
+    base_plus_ext,
+    tar_file_expander,
+    url_opener,
+    valid_sample,
+)
+
+from groundingdino.demo.caption_grounder import caption_grounder
+from groundingdino.demo.inference_on_laion import add_loc_to_text
+from groundingdino.demo.inference_on_laion import nms_without_score
+from groundingdino.demo.inference_on_laion import calculate_iou
+
+Image.MAX_IMAGE_PIXELS = 1000000000
+LAION2B_NUM_SAMPLE = 1500000000
+VQAV2_TRAIN_NUM_SAMPLE = 1828467
+VG_RELATION_BBOX_SIZE = 600
+
+REL_LABELS = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind', 'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for', 'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on', 'looking at', 'lying on', 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over', 'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on', 'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']
+
+try:
+    import horovod.torch as hvd
+except ImportError:
+    hvd = None
+
+class ConcatDataset(IterableDataset):
+    def __init__(
+            self, dataset, max_length,
+            delimiter_id, pad_id=None, media_id=None, endofmedia_id=None,
+            image_embedding_size=-2, single=False, box_id=None, visual_id=None,
+        ):
+        self.dataset = dataset
+        self.max_length = max_length
+        self.delimiter_id = torch.ones(1,1).long() * delimiter_id
+        if pad_id is not None:
+            self.pad_id = int(pad_id)
+        if media_id is not None:
+            self.media_id = torch.ones(1,1).long() * int(media_id)
+        if endofmedia_id is not None:
+            self.endofmedia_id = torch.ones(1,1).long() * int(endofmedia_id)
+        if image_embedding_size > 0:
+            logging.info(f"image_embedding_size: {image_embedding_size}")
+        self.image_embedding_size = image_embedding_size + 2
+        self.single = single
+        self.box_id = box_id
+        self.visual_id = visual_id
+    
+    def __iter__(self):
+        while True:
+            input_ids_list = []
+            attention_mask_list = []
+            image_list = []
+            image_start_index_list = []
+            added_bbox_list = []
+            relations_list = []
+            cnt = 0
+            while cnt < self.max_length:
+                sample = next(self.dataset)
+                if len(sample) >= 4:
+                    image = sample[0].unsqueeze(0)
+                    input_ids = sample[1]
+                    attention_mask = sample[2]
+                    added_bbox = sample[3]
+                    image_list.append(image)
+                    added_bbox_list.append(added_bbox)
+                    if len(sample) == 5:
+                        relations_list.append(sample[4])
+                else:
+                    sample = sample[0]
+                    input_ids = sample[0]
+                    attention_mask = sample[1]
+                input_ids_list.append(input_ids)
+                attention_mask_list.append(attention_mask)
+                cnt += input_ids.shape[-1]
+                if self.single:
+                    break
+            input_ids = torch.cat(input_ids_list, dim=-1)[0]
+            attention_mask = torch.cat(attention_mask_list, dim=-1)[0]
+            if not self.single:
+                input_ids = input_ids[:self.max_length]
+                attention_mask = attention_mask[:self.max_length]
+            # TODO: fix visual number not match
+            if len(image_list) != 0:
+                images = torch.cat(image_list, dim=0)
+                image_begin = (input_ids == self.media_id[0,0]).nonzero().view(-1)
+                image_end = (input_ids == self.endofmedia_id[0,0]).nonzero().view(-1)
+                if len(image_begin) != len(image_end):
+                    assert len(image_begin) == len(image_end) + 1
+                    input_ids[image_begin[-1]:] = self.pad_id
+                    attention_mask[image_begin[-1]:] = 0
+                    image_begin = image_begin[:-1]
+                eos_token_num = len((input_ids == self.delimiter_id[0,0]).nonzero().view(-1))
+                if eos_token_num != len(image_begin) + 1:
+                    input_ids[image_begin[-1]:] = self.pad_id
+                    attention_mask[image_begin[-1]:] = 0
+                    image_begin = image_begin[:-1]
+                    image_end = image_end[:-1]
+                images = images[:len(image_end)]
+                added_bbox_list = added_bbox_list[:len(image_end)]
+                relations_list = relations_list[:len(image_end)]
+                image_start_index_list = (image_begin + 1).tolist()
+                expand_list = added_bbox_list[0]
+                for x in added_bbox_list[1:]:
+                    expand_list.extend(x)
+                yield images, len(images), image_start_index_list, input_ids, attention_mask, expand_list, relations_list
+            else:
+                yield input_ids, attention_mask
+
+
+class SharedEpoch:
+    def __init__(self, epoch: int = 0):
+        self.shared_epoch = Value("i", epoch)
+
+    def set_value(self, epoch):
+        self.shared_epoch.value = epoch
+
+    def get_value(self):
+        return self.shared_epoch.value
+
+
+@dataclass
+class DataInfo:
+    dataloader: DataLoader
+    sampler: DistributedSampler = None
+    shared_epoch: SharedEpoch = None
+
+    def set_epoch(self, epoch):
+        if self.shared_epoch is not None:
+            self.shared_epoch.set_value(epoch)
+        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
+            self.sampler.set_epoch(epoch)
+
+
+def filter_no_caption_or_no_image(sample):
+    return ("txt" in sample) and (
+        "png" in sample or "jpg" in sample or "jpeg" in sample
+    )
+
+
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
+    if "ValueError" in repr(exn) or "KeyError" in repr(exn):  # Avoid spamming logs with these
+        return True
+    logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
+    return True
+# DEBUG
+# log_and_continue = None
+# DEBUG
+
+
+def group_by_keys_nothrow(
+    data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
+):
+    """Return function over iterator that groups key, value pairs into samples.
+
+    :param keys: function that splits the key into key and extension (base_plus_ext)
+    :param lcase: convert suffixes to lower case (Default value = True)
+    """
+    current_sample = None
+    tar_idx = None
+    for filesample in data:
+        assert isinstance(filesample, dict)
+        current_tar_idx = filesample["__url__"].split("/")[-1].split(".")[0]
+        if current_tar_idx != tar_idx:
+            tar_idx = current_tar_idx
+            if "blip2_all_data_ground" in filesample["__url__"]:
+                relation_data_dir = os.path.join("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_relation", tar_idx)
+                missing_file = False
+                try:
+                    data_info = pkl.load(open(os.path.join(relation_data_dir, "custom_data_info.pkl"), "rb"))
+                    prediction = pkl.load(open(os.path.join(relation_data_dir, "custom_prediction.pkl"), "rb"))
+                    idx_to_files = data_info["idx_to_files"]
+                    ind_to_classes = data_info["ind_to_classes"]
+                    ind_to_predicates = data_info["ind_to_predicates"]
+                    files_to_idx = {x.split("#")[-1]: i for i, x in enumerate(idx_to_files)}
+                except:
+                    missing_file = True
+        fname, value = filesample["fname"], filesample["data"]
+        prefix, suffix = keys(fname)
+        if prefix is None:
+            continue
+        if lcase:
+            suffix = suffix.lower()
+        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+        if (
+            current_sample is None
+            or prefix != current_sample["__key__"]
+            or suffix in current_sample
+        ):
+            if valid_sample(current_sample):
+                yield current_sample
+            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
+            if "blip2_all_data_ground" in filesample["__url__"] and not missing_file:
+                try:
+                    idx = files_to_idx[prefix]
+                    prediction[idx]["bbox"] = [np.array(bbox)/VG_RELATION_BBOX_SIZE for bbox in prediction[idx]["bbox"]]
+                    current_sample["relation_data"] = prediction[idx]
+                except:
+                    current_sample["relation_data"] = dict()
+            else:
+                current_sample["relation_data"] = dict()
+        if suffixes is None or suffix in suffixes:
+            current_sample[suffix] = value
+    if valid_sample(current_sample):
+        yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=log_and_continue):
+    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+    streams = url_opener(src, handler=handler)
+    files = tar_file_expander(streams, handler=handler)
+    samples = group_by_keys_nothrow(files, handler=handler)
+    return samples
+
+
+def pytorch_worker_seed(increment=0):
+    """get dataloader worker seed from pytorch"""
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        # favour using the seed already created for pytorch dataloader workers if it exists
+        seed = worker_info.seed
+        if increment:
+            # space out seed increments so they can't overlap across workers in different iterations
+            seed += increment * max(1, worker_info.num_workers)
+        return seed
+    # fallback to wds rank based seed
+    return wds.utils.pytorch_worker_seed()
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+class ResampledShards2(IterableDataset):
+    """An iterable dataset yielding a list of urls."""
+
+    def __init__(
+        self,
+        urls,
+        nshards=sys.maxsize,
+        worker_seed=None,
+        deterministic=False,
+        epoch=-1,
+    ):
+        """Sample shards from the shard list with replacement.
+        :param urls: a list of URLs as a Python list or brace notation string
+        """
+        super().__init__()
+        urls = wds.shardlists.expand_urls(urls)
+        self.urls = urls
+        assert isinstance(self.urls[0], str)
+        self.nshards = nshards
+        self.rng = random.Random()
+        self.worker_seed = worker_seed
+        self.deterministic = deterministic
+        self.epoch = epoch
+
+    def __iter__(self):
+        """Return an iterator over the shards."""
+        if isinstance(self.epoch, SharedEpoch):
+            epoch = self.epoch.get_value()
+        else:
+            # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
+            # situation as different workers may wrap at different times (or not at all).
+            self.epoch += 1
+            epoch = self.epoch
+
+        if self.deterministic:
+            # reset seed w/ epoch if deterministic
+            if self.worker_seed is None:
+                # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
+                seed = pytorch_worker_seed(epoch)
+            else:
+                seed = self.worker_seed() + epoch
+            seed = seed + int(time.time())
+            self.rng.seed(seed)
+            # logging.info(f"epoch: {epoch} seed: {seed}")
+        self.rng.shuffle(self.urls)
+        # logging.info(f"{len(self.urls)} | {self.urls[:2]}")
+        for url in self.urls:
+            # logging.info(f"{seed}: {url}")
+            yield dict(url=url)
+
+
+def preprocess_image(sample, image_processor):
+    image = image_processor(sample)
+    return image
+
+
+def preprocess_text(sample, tokenizer, max_length, single=False):
+    if not single:
+        text = tokenizer(tokenizer.bos_token+sample.strip(), return_tensors="pt", max_length=max_length, truncation=True)
+    else:
+        text = tokenizer(tokenizer.bos_token+sample.strip(), return_tensors="pt", max_length=max_length, truncation=True, padding='max_length')
+    return text["input_ids"], text["attention_mask"]
+
+
+def preprocess_encoded_text(sample, tokenizer, max_length):
+    sample = sample.decode("utf-8")
+    return preprocess_text(sample, tokenizer, max_length=max_length)
+
+
+def _merge_bbox_previsual(added_bbox_list):
+    bbox_list = []
+    for bboxes in added_bbox_list:
+        x1 = bboxes[:, 0].min()
+        y1 = bboxes[:, 1].min()
+        x2 = bboxes[:, 2].max()
+        y2 = bboxes[:, 3].max()
+        bbox_list.append(torch.tensor([x1, y1, x2, y2], device=bboxes.device, dtype=bboxes.dtype).unsqueeze(0))
+    return bbox_list
+
+
+def _find_idx(text, subtext):
+    loc = 0
+    locs = []
+    while text.find(subtext, loc) != -1:
+        loc = text.find(subtext, loc)
+        locs.append(loc)
+        loc += len(subtext)
+    return locs
+
+def preprocess_ground_caption(sample, image_processor, tokenizer, image_embedding_size, generator, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None, args=None):
+    assert max_length is not None
+    assert not single, "single is not supported for preprocess_ground_caption"
+    image, caption, logits_filt, boxes_filt, relation_data = sample
+    if len(logits_filt.shape) == 1 and logits_filt.shape[0] == 4 and len(boxes_filt.shape) == 1 and boxes_filt.shape[0] == 4:
+        raise NotImplementedError # lack relation data
+        return preprocess_visual_genome(sample=sample, image_processor=image_processor, tokenizer=tokenizer, image_embedding_size=image_embedding_size, prob_ground=prob_ground, single=single, use_format_v2=use_format_v2, add_visual_token=add_visual_token, max_length=max_length)
+    image = preprocess_image(image, image_processor=image_processor)
+    added_bbox = []
+    if (prob_ground != 0 and random.random() <= prob_ground) or prob_ground == 1.0:
+        boxes_filt, pred_phrases = generator.postprocess(logits_filt, boxes_filt, generator.ground_model, caption, generator.text_threshold, generator.box_threshold, with_logits=True)
+        caption, added_bbox = add_loc_to_text(
+            boxes_filt, pred_phrases, caption,
+            expand=args.expand, always_expand=args.longer_previsual,
+        )
+    visual_loc = []
+    obj_loc = []
+    endofobj_loc = []
+    visual_token = "<|#visual#|>"
+    previsual_token = "<|#previsual#|>"
+    box_token = "<|#box#|>"
+    prebox_token = "<|#prebox#|>"
+    end_token = "<|#endofobject#|>"
+    object_token = "<|#object#|>"
+    end_of_attr_token = "<|#endofattr#|>"
+    preend_of_attr_token = "<|#preendofattr#|>"
+    visual_loc = _find_idx(caption, visual_token)
+    try:
+        if len(visual_loc) != len(added_bbox):
+            logging.warning(f"visual_loc: {visual_loc}")
+            logging.warning(f"added_bbox: {added_bbox}")
+    except:
+        pass
+    assert len(visual_loc) == len(added_bbox)
+    delta = 0
+    for i, (loc, boxes) in enumerate(zip(visual_loc, added_bbox)):
+        loc += delta
+        boxes = nms_without_score(boxes)
+        added_bbox[i] = boxes
+        added_tokens = end_token + visual_token + box_token * len(boxes) + end_of_attr_token
+        caption = caption[:loc] + added_tokens + caption[len(visual_token) + loc:]
+        delta += len(added_tokens) - len(visual_token)
+
+    if use_format_v2:
+        merge_added_bbox = _merge_bbox_previsual(added_bbox)
+        # step 1: move <|#object#|> before the space char
+        while caption.find(f" {object_token}") != -1:
+            caption = caption.replace(f" {object_token}", f"{object_token} ")
+        # step 2: add <|#previsual#|> after <|#object#|> for 75% except the first object
+        i = 0
+        II = -1
+        if args.no_visual:
+            flag = False
+            delete_visual_prob = 10.0
+        else:
+            flag = True
+            delete_visual_prob = 0.75
+        while i < len(caption):
+            if caption[i: i + len(object_token)] == object_token:
+                II += 1
+                if (not args.longer_previsual and not flag and random.random() < delete_visual_prob) or (args.longer_previsual and (flag or random.random() < delete_visual_prob)):
+                    # delete visual and add previsual
+                    visual_start_idx = caption.find(end_token, i+1) + len(end_token)
+                    visual_end_idx = caption.find(end_of_attr_token, visual_start_idx+1) + len(end_of_attr_token)
+                    caption = caption[:visual_start_idx] + caption[visual_end_idx:]
+                    caption = caption[:i + len(object_token)] + previsual_token + prebox_token + preend_of_attr_token + caption[i + len(object_token):]
+                    added_bbox[II] = merge_added_bbox[II]
+            i += 1
+            flag = False
+        if args.no_previsual and args.no_visual:
+            caption = caption.replace(previsual_token, "").replace(prebox_token, "").replace(preend_of_attr_token, "")
+            added_bbox = []
+        caption = caption.replace(preend_of_attr_token, object_token).replace(end_of_attr_token, end_token)
+
+
+    if args.roi_align:
+        i = 0
+        pad_num = args.roi_output_size ** 2 - 1
+        while i < len(caption):
+            if caption[i: i + len(prebox_token)] == prebox_token:
+                caption = caption[:i] + tokenizer.pad_token * pad_num + caption[i:]
+                i += len(tokenizer.pad_token) * pad_num + len(prebox_token)
+            elif caption[i: i + len(box_token)] == box_token:
+                caption = caption[:i] + tokenizer.pad_token * pad_num + caption[i:]
+                i += len(tokenizer.pad_token) * pad_num + len(box_token)
+            i += 1
+
+    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
+    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
+    relations = []
+    if args.only_grounded_sample and "<|#visual#|>" not in caption:
+        raise ValueError
+    return image, input_ids, attention_mask, added_bbox, relations
+
+
+def preprocess_visual_genome(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
+    assert max_length is not None
+    assert not single, "single is not supported for preprocess_ground_caption"
+    image, caption, xyxy, _ = sample
+    image = preprocess_image(image, image_processor=image_processor)
+    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|><|#object#|>" + caption.strip() + "<|#endofobject#|><|#visual#|><|#box#|><|#endofattr#|>"
+    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length)
+    added_bbox = [torch.tensor(np.expand_dims(xyxy, 0).astype(np.float32) / 224)]
+    return image, input_ids, attention_mask, added_bbox
+
+special_predicate = [
+    "and",
+    "has",
+    "says",
+    "wears",
+]
+
+original_predicate = {
+    "and": "and",
+    "has": "have",
+    "says": "say",
+    "wears": "wear",
+}
+
+
+def generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation):
+    if relation in ["and", "of"]:
+        id = 0
+    else:
+        id = random.choice(range(len(VG_RELATION_TEMPLATES)))
+    text = VG_RELATION_TEMPLATES[id].format(nameA=nameA, nameB=nameB, relation=relation, use_is="is" if relation not in special_predicate else "", is_or_does="is" if relation not in special_predicate else "does", relation_do=relation if relation not in special_predicate else original_predicate[relation])
+    if id in [0]:
+        added_bbox = [
+            torch.tensor([boxA]),
+            torch.tensor([boxB]),
+        ]
+    elif id in [1]:
+        added_bbox = [
+            torch.tensor([boxA]),
+            torch.tensor([boxB]),
+            torch.tensor([boxA]),
+            torch.tensor([boxB]),
+        ]
+    elif id in [2]:
+        added_bbox = [
+            torch.tensor([boxA]),
+            torch.tensor([boxA]),
+            torch.tensor([boxB]),
+        ]
+    elif id in [3]:
+        added_bbox = [
+            torch.tensor([boxB]),
+            torch.tensor([boxA]),
+            torch.tensor([boxB]),
+        ]
+    elif id in [4]:
+        added_bbox = [
+            torch.tensor([boxA]),
+            torch.tensor([boxB]),
+        ]
+    elif id in [5]:
+        added_bbox = [
+            torch.tensor([boxB]),
+            torch.tensor([boxA]),
+        ]
+    else:
+        raise NotImplementedError
+    return text, added_bbox
+
+def generate_pisc_sample(boxA, boxB, relation):
+    id = random.choice(range(len(PISC_TEMPLATES)))
+    text = PISC_TEMPLATES[id].format(relation=relation)
+    if id in [0]:
+        if random.random() < 0.5:
+            added_bbox = [
+                torch.tensor([boxA]),
+                torch.tensor([boxB]),
+            ]
+        else:
+            added_bbox = [
+                torch.tensor([boxB]),
+                torch.tensor([boxA]),
+            ]
+    elif id in [1]:
+        if random.random() < 0.5:
+            added_bbox = [torch.tensor([boxA, boxB])]
+        else:
+            added_bbox = [torch.tensor([boxB, boxA])]
+    return text, added_bbox
+
+
+def preprocess_instruct(sample, image_processor, tokenizer, image_embedding_size, prob_ground=1.0, single=False, use_format_v2=False, add_visual_token=False, max_length=None):
+    image_path, dataset, data = sample
+    image = Image.open(image_path)
+    size = image_processor.transforms[0].size
+    image = image.resize((size, size))
+    if dataset == "pisc_relation_split":
+        boxA = data[0]
+        boxB = data[1]
+        relation = data[2]
+        text, added_bbox = generate_pisc_sample(boxA, boxB, relation)
+        # import cv2
+        # boxA *= size
+        # boxB *= size
+        # open_cv_image = np.array(image)
+        # open_cv_image = open_cv_image[:, :, ::-1].copy() 
+        # open_cv_image = cv2.rectangle(open_cv_image, boxA[:2].astype(int), boxA[2:].astype(int), (255, 0, 0), 2)
+        # open_cv_image = cv2.rectangle(open_cv_image, boxB[:2].astype(int), boxB[2:].astype(int), (0, 255, 0), 2)
+        # cv2.imwrite("output.jpg", open_cv_image)
+        # import pdb; pdb.set_trace()
+    elif dataset == "vg_relation":
+        boxA = data[0][0]
+        nameA = data[0][1]
+        boxB = data[1][0]
+        nameB = data[1][1]
+        relation = data[2]
+        text, added_bbox = generate_vg_relation_sample(boxA, boxB, nameA, nameB, relation)
+    image = preprocess_image(image, image_processor=image_processor)
+    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + text + tokenizer.eos_token
+    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=True)
+    # return image, input_ids, attention_mask, added_bbox
+    images = image.unsqueeze(0)
+    image_start_index_list = [2]
+    return images, len(images), image_start_index_list, input_ids, attention_mask, added_bbox
+
+
+def preprocess_caption(sample, image_processor, tokenizer, image_embedding_size, max_length, single=False):
+    image, caption = sample
+    caption = f"<|#image#|>{tokenizer.pad_token*image_embedding_size}<|#endofimage#|>" + caption
+    image = preprocess_image(image, image_processor=image_processor)
+    input_ids, attention_mask = preprocess_text(caption, tokenizer, max_length=max_length, single=single)
+    return image, input_ids, attention_mask
+
+
+def get_pile_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
+    input_shards = args.pile_shards
+    assert input_shards is not None
+    resampled = getattr(args, "dataset_resampled", False)
+    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
+
+    # create a shared epoch store to sync epoch to dataloader worker proc
+    shared_epoch = SharedEpoch(epoch=epoch)
+    preprocess_text_fn = functools.partial(preprocess_encoded_text, tokenizer=tokenizer, max_length=args.max_length)
+    pipeline = [
+        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
+        tarfile_to_samples_nothrow,
+        wds.shuffle(
+            bufsize=_SAMPLE_SHUFFLE_SIZE,
+            initial=_SAMPLE_SHUFFLE_INITIAL,
+        ),
+        wds.to_tuple("txt", handler=log_and_continue),
+        wds.map_tuple(
+            preprocess_text_fn, handler=log_and_continue
+        ),
+    ]
+    # with_epoch(sys.maxsize) will give us an infinite sample stream
+    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
+    delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
+    dataset = ConcatDataset(iter(dataset), max_length=args.max_length, delimiter_id=delimiter_id)
+
+
+    def text_collate_fn(items):
+        try:
+            input_ids = torch.cat([x[0].unsqueeze(0) for x in items], dim=0)
+            attention_mask = torch.cat([x[1].unsqueeze(0) for x in items], dim=0)
+            return input_ids, attention_mask
+        except:
+            return None, None
+
+    dataloader = wds.WebLoader(
+        dataset,
+        batch_size=args.batch_size_pile,
+        shuffle=False,
+        num_workers=args.workers,
+        persistent_workers=False,
+        collate_fn=text_collate_fn,
+    )
+    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
+
+
+# FIXME:
+# modify /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/webdataset/filters.py, line 433
+# combine_tensors=True to combine_tensors=False
+def get_ground_laion_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
+    input_shards = args.laion_shards
+    assert input_shards is not None
+    resampled = getattr(args, "dataset_resampled", False)
+    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
+    # create a shared epoch store to sync epoch to dataloader worker proc
+    shared_epoch = SharedEpoch(epoch=epoch)
+    generator = caption_grounder(
+        config_file="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
+        checkpoint_path="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/GroundingDINO/checkpoints/groundingdino_swint_ogc.pth",
+        cpu_only=True,
+        # box_threshold=0.5, text_threshold=0.3,
+    )
+    preprocess_ground_caption_fn = functools.partial(
+        preprocess_ground_caption, image_processor=image_processor, tokenizer=tokenizer,
+        image_embedding_size=args.vis_embed_size, single=args.single, generator=generator,
+        prob_ground=args.prob_ground, use_format_v2=args.use_format_v2,
+        add_visual_token=args.add_visual_token, max_length=args.max_length,
+        args=args,
+    )
+    pipeline = [
+        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
+        tarfile_to_samples_nothrow,
+        wds.shuffle(
+            bufsize=_SAMPLE_SHUFFLE_SIZE,
+            initial=_SAMPLE_SHUFFLE_INITIAL,
+        ),
+        wds.select(filter_no_caption_or_no_image),
+        wds.decode("pilrgb", partial=True, handler=log_and_continue),
+        wds.to_tuple("jpg;png;jpeg", "txt", "logits.pyd", "boxes.pyd", "relation_data", handler=log_and_continue),
+        wds.map(
+            preprocess_ground_caption_fn, handler=log_and_continue
+        ),
+    ]
+
+    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
+    # for sample in dataset:
+    #     print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
+    # DEBUG
+    # dataset = wds.DataPipeline(*pipeline)
+    # from tqdm import tqdm
+    # for sample in tqdm(dataset):
+    #     nn = 0
+    #     for x in sample[1][0]:
+    #         if x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]:
+    #             nn += 1
+    #         if x == tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]:
+    #             nn -= 1
+    #         if nn not in [0, 1]:
+    #             print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
+    #             import pdb; pdb.set_trace()
+    #     if nn != 0:
+    #         print(tokenizer.decode(sample[1][0]).replace("<PAD>", ""))
+    #         import pdb; pdb.set_trace()
+    # from groundingdino.demo.inference_on_laion import OBJ_LENGTHS
+    # # import pdb; pdb.set_trace()
+    # print(sum(OBJ_LENGTHS) / len(OBJ_LENGTHS))
+    # exit()
+    # DEBUG
+
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    box_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    dataset = ConcatDataset(
+        iter(dataset), max_length=args.max_length,
+        delimiter_id=delimiter_id,
+        pad_id=tokenizer.pad_token_id,
+        media_id=media_token_id,
+        endofmedia_id=endofmedia_token_id,
+        box_id=box_id,
+        visual_id=visual_id,
+        image_embedding_size=args.vis_embed_size,
+        single=args.single,
+    )
+
+    def image_collate_fn(items):
+        images = torch.cat([x[0] for x in items], dim=0)
+        image_nums = [x[1] for x in items]
+        image_start_index_list = [x[2] for x in items]
+        input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
+        attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
+        added_bbox_list = [x[5] for x in items]
+        expand_list = added_bbox_list[0]
+        for x in added_bbox_list[1:]:
+            expand_list.extend(x)
+        relations_list = [x[6] for x in items]
+        return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list, relations_list
+
+    dataloader = wds.WebLoader(
+        dataset,
+        batch_size=args.batch_size_laion,
+        shuffle=False,
+        num_workers=args.workers,
+        persistent_workers=False,
+        collate_fn=image_collate_fn,
+    )
+    round_fn = math.floor if floor else math.ceil
+    global_batch_size = args.batch_size_laion * args.world_size
+    num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
+    dataloader.num_batches = num_batches
+    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
+
+
+def get_image_text_pair_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
+    input_shards = args.laion_shards
+    assert input_shards is not None
+    resampled = getattr(args, "dataset_resampled", False)
+    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
+    # create a shared epoch store to sync epoch to dataloader worker proc
+    shared_epoch = SharedEpoch(epoch=epoch)
+    preprocess_caption_fn = functools.partial(
+        preprocess_caption, image_processor=image_processor, tokenizer=tokenizer,
+        image_embedding_size=args.vis_embed_size, single=args.single,
+        max_length=args.max_length,
+    )
+    pipeline = [
+        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
+        tarfile_to_samples_nothrow,
+        wds.shuffle(
+            bufsize=_SAMPLE_SHUFFLE_SIZE,
+            initial=_SAMPLE_SHUFFLE_INITIAL,
+        ),
+        wds.select(filter_no_caption_or_no_image),
+        wds.decode("pilrgb", handler=log_and_continue),
+        wds.to_tuple("jpg;png;jpeg", "txt", handler=log_and_continue),
+        wds.map(
+            preprocess_caption_fn, handler=log_and_continue
+        ),
+    ]
+
+    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    delimiter_id = tokenizer(tokenizer.eos_token, add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    dataset = ConcatDataset(
+        iter(dataset), max_length=args.max_length,
+        delimiter_id=delimiter_id,
+        pad_id=tokenizer.pad_token_id,
+        media_id=media_token_id,
+        endofmedia_id=endofmedia_token_id,
+        image_embedding_size=args.vis_embed_size,
+        single=args.single,
+    )
+
+    def image_collate_fn(items):
+        images = torch.cat([x[0] for x in items], dim=0)
+        image_nums = [x[1] for x in items]
+        image_start_index_list = [x[2] for x in items]
+        input_ids = torch.cat([x[3].unsqueeze(0) for x in items], dim=0)
+        attention_mask = torch.cat([x[4].unsqueeze(0) for x in items], dim=0)
+        return images, image_nums, image_start_index_list, input_ids, attention_mask
+
+    dataloader = wds.WebLoader(
+        dataset,
+        batch_size=args.batch_size_laion,
+        shuffle=False,
+        num_workers=args.workers,
+        persistent_workers=False,
+        collate_fn=image_collate_fn,
+    )
+    round_fn = math.floor if floor else math.ceil
+    global_batch_size = args.batch_size_laion * args.world_size
+    num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
+    dataloader.num_batches = num_batches
+    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
+
+
+def get_instruct_dataset(args, image_processor, tokenizer, epoch=0, floor=False):
+    input_shards = args.laion_shards
+    assert input_shards is not None
+    resampled = getattr(args, "dataset_resampled", False)
+    assert resampled, "turn on dataset_resampled to allow infinite stream of samples"
+    # create a shared epoch store to sync epoch to dataloader worker proc
+    shared_epoch = SharedEpoch(epoch=epoch)
+    preprocess_instruct_fn = functools.partial(
+        preprocess_instruct, image_processor=image_processor, tokenizer=tokenizer,
+        image_embedding_size=args.vis_embed_size,
+        max_length=args.max_length,
+    )
+    pipeline = [
+        ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch),
+        tarfile_to_samples_nothrow,
+        wds.shuffle(
+            bufsize=_SAMPLE_SHUFFLE_SIZE,
+            initial=_SAMPLE_SHUFFLE_INITIAL,
+        ),
+        wds.decode(partial=True),
+        wds.to_tuple("image_path.txt", "dataset.txt", "data.pyd", handler=log_and_continue),
+        wds.map(
+            preprocess_instruct_fn, handler=log_and_continue
+        ),
+    ]
+    dataset = wds.DataPipeline(*pipeline).with_epoch(sys.maxsize)
+
+    def image_collate_fn(items):
+        images = torch.cat([x[0] for x in items], dim=0)
+        image_nums = [x[1] for x in items]
+        image_start_index_list = [x[2] for x in items]
+        input_ids = torch.cat([x[3] for x in items], dim=0)
+        attention_mask = torch.cat([x[4] for x in items], dim=0)
+        added_bbox_list = [x[5] for x in items]
+        expand_list = added_bbox_list[0]
+        for x in added_bbox_list[1:]:
+            expand_list.extend(x)
+        return images, image_nums, image_start_index_list, input_ids, attention_mask, expand_list
+
+    dataloader = wds.WebLoader(
+        dataset,
+        batch_size=args.batch_size_laion,
+        shuffle=False,
+        num_workers=args.workers,
+        persistent_workers=False,
+        collate_fn=image_collate_fn,
+    )
+    round_fn = math.floor if floor else math.ceil
+    global_batch_size = args.batch_size_laion * args.world_size
+    num_batches = round_fn(LAION2B_NUM_SAMPLE / global_batch_size)
+    dataloader.num_batches = num_batches
+    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
+
+
+def get_dataset_fn(dataset_type):
+    if dataset_type == "mmc4":
+        raise NotImplementedError
+    elif dataset_type == "pile":
+        return get_pile_dataset
+    elif dataset_type == "ground_image_text":
+        return get_ground_laion_dataset
+    elif dataset_type == "image_text":
+        return get_image_text_pair_dataset
+    elif dataset_type == "vqav2":
+        raise NotImplementedError
+    elif dataset_type == "instruct":
+        return get_instruct_dataset
+    else:
+        raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+
+def get_data(args, image_processor, tokenizer, dataset_type, epoch=0):
+    return get_dataset_fn(dataset_type)(
+        args, image_processor=image_processor, epoch=epoch, tokenizer=tokenizer
+    )
diff --git a/multimodal/open_flamingo/train/distributed.py b/multimodal/open_flamingo/train/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..3938d063d52218eefdfde83f850999b022538c47
--- /dev/null
+++ b/multimodal/open_flamingo/train/distributed.py
@@ -0,0 +1,128 @@
+import os
+
+import torch
+
+try:
+    import horovod.torch as hvd
+except ImportError:
+    hvd = None
+
+
+def is_global_master(args):
+    return args.rank == 0
+
+
+def is_local_master(args):
+    return args.local_rank == 0
+
+
+def is_master(args, local=False):
+    return is_local_master(args) if local else is_global_master(args)
+
+
+def is_using_horovod():
+    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
+    # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
+    ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
+    pmi_vars = ["PMI_RANK", "PMI_SIZE"]
+    if all([var in os.environ for var in ompi_vars]) or all(
+        [var in os.environ for var in pmi_vars]
+    ):
+        return True
+    else:
+        return False
+
+
+def is_using_distributed():
+    if "WORLD_SIZE" in os.environ:
+        return int(os.environ["WORLD_SIZE"]) > 1
+    if "SLURM_NTASKS" in os.environ:
+        return int(os.environ["SLURM_NTASKS"]) > 1
+    return False
+
+
+def world_info_from_env():
+    local_rank = 0
+    for v in (
+        "LOCAL_RANK",
+        "MPI_LOCALRANKID",
+        "SLURM_LOCALID",
+        "OMPI_COMM_WORLD_LOCAL_RANK",
+    ):
+        if v in os.environ:
+            local_rank = int(os.environ[v])
+            break
+    global_rank = 0
+    for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
+        if v in os.environ:
+            global_rank = int(os.environ[v])
+            break
+    world_size = 1
+    for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
+        if v in os.environ:
+            world_size = int(os.environ[v])
+            break
+
+    return local_rank, global_rank, world_size
+
+
+def init_distributed_device(args):
+    # Distributed training = training on more than one GPU.
+    # Works in both single and multi-node scenarios.
+    args.distributed = False
+    args.world_size = 1
+    args.rank = 0  # global rank
+    args.local_rank = 0
+    if args.horovod:
+        assert hvd is not None, "Horovod is not installed"
+        hvd.init()
+        args.local_rank = int(hvd.local_rank())
+        args.rank = hvd.rank()
+        args.world_size = hvd.size()
+        args.distributed = True
+        os.environ["LOCAL_RANK"] = str(args.local_rank)
+        os.environ["RANK"] = str(args.rank)
+        os.environ["WORLD_SIZE"] = str(args.world_size)
+    elif is_using_distributed():
+        if "SLURM_PROCID" in os.environ:
+            # DDP via SLURM
+            args.local_rank, args.rank, args.world_size = world_info_from_env()
+            # SLURM var -> torch.distributed vars in case needed
+            os.environ["LOCAL_RANK"] = str(args.local_rank)
+            os.environ["RANK"] = str(args.rank)
+            os.environ["WORLD_SIZE"] = str(args.world_size)
+            torch.distributed.init_process_group(
+                backend=args.dist_backend,
+                init_method=args.dist_url,
+                world_size=args.world_size,
+                rank=args.rank,
+            )
+        else:
+            # DDP via torchrun, torch.distributed.launch
+            args.local_rank, _, _ = world_info_from_env()
+            torch.distributed.init_process_group(
+                backend=args.dist_backend, init_method=args.dist_url
+            )
+            args.world_size = torch.distributed.get_world_size()
+            args.rank = torch.distributed.get_rank()
+        args.distributed = True
+    else:
+        # needed to run on single gpu
+        torch.distributed.init_process_group(
+            backend=args.dist_backend,
+            init_method=args.dist_url,
+            world_size=1,
+            rank=0,
+        )
+
+    if torch.cuda.is_available():
+        if args.distributed and not args.no_set_device_rank:
+            device = "cuda:%d" % args.local_rank
+        else:
+            device = "cuda:0"
+        torch.cuda.set_device(device)
+    else:
+        device = "cpu"
+    args.device = device
+    device = torch.device(device)
+    return device
diff --git a/multimodal/open_flamingo/train/instruction_template.py b/multimodal/open_flamingo/train/instruction_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b449fd79a1d97241c33f0ea0d9eace91b63466d
--- /dev/null
+++ b/multimodal/open_flamingo/train/instruction_template.py
@@ -0,0 +1,13 @@
+VG_RELATION_TEMPLATES = [
+    "Question: What is the relationship between<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
+    "Question: What is the relationship between<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
+    "Question: What {is_or_does}<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {relation_do}? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|>{nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
+    "Question: What {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
+    "Question: What {is_or_does}<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> {relation_do}? Answer:<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
+    "Question: What {use_is} {relation}<|#object#|> {nameB}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer:<|#object#|> {nameA}<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>.",
+]
+
+PISC_TEMPLATES = [
+    "Question: What is the social relationship between this<|#object#|> person<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|> and that<|#object#|> person<|#endofobject#|><|#visual#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
+    "Question: What is the social relationship between these<|#object#|> people<|#endofobject#|><|#visual#|><|#box#|><|#box#|><|#endofobject#|>? Answer: {relation}.",
+]
diff --git a/multimodal/open_flamingo/train/train.py b/multimodal/open_flamingo/train/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e77a07909ec09e691b173957c0b70eaf261b4f4
--- /dev/null
+++ b/multimodal/open_flamingo/train/train.py
@@ -0,0 +1,709 @@
+""" Main training script """
+
+import argparse
+import copy
+import glob
+import os
+import random
+import functools
+
+import numpy as np
+import torch
+# torch.multiprocessing.set_sharing_strategy('file_system')
+import wandb
+from data2 import get_data
+from distributed import init_distributed_device, world_info_from_env
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    MixedPrecision,
+    BackwardPrefetch,
+    ShardingStrategy,
+    FullStateDictConfig,
+    CPUOffload,
+    StateDictType,
+)
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
+
+from train_utils import train_one_epoch
+from transformers import (
+    get_constant_schedule_with_warmup,
+    get_cosine_schedule_with_warmup,
+    get_linear_schedule_with_warmup,
+)
+
+from open_flamingo import create_model_and_transforms
+from torch.utils.tensorboard import SummaryWriter
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.cuda.amp import GradScaler
+from torch.distributed.optim import ZeroRedundancyOptimizer
+import warnings
+warnings.filterwarnings("ignore")
+import logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s %(message)s',
+    datefmt='%m/%d %I:%M:%S',
+)
+
+class FakeDataloader:
+    def __iter__(self):
+        return self
+    
+    def __next__(self):
+        return None
+
+def random_seed(seed=42, rank=0):
+    torch.manual_seed(seed + rank)
+    np.random.seed(seed + rank)
+    random.seed(seed + rank)
+
+
+def get_grouped_params(model, args):
+    params_with_wd, params_without_wd = [], []
+
+    def apply_decay(x):
+        x = x.lower()
+        return "norm" not in x and "bn" not in x and "bias" not in x and "embed" not in x and "wte" not in x and "flat_param" not in x
+
+    for n, p in model.named_parameters():
+        # if p.requires_grad:
+        if apply_decay(n):
+            if torch.distributed.get_rank() == 0:
+                logging.info(f"with wd: {n}")
+            params_with_wd.append(p)
+        else:
+            if torch.distributed.get_rank() == 0:
+                logging.info(f"without wd: {n}")
+            params_without_wd.append(p)
+    return [
+        {"params": params_with_wd, "weight_decay": args.weight_decay},
+        {"params": params_without_wd, "weight_decay": 0.0},
+    ]
+
+
+def lambda_policy_fn(module):
+    if (
+        len(list(module.named_children())) == 0
+        and getattr(module, "weight", None) is not None
+        and module.weight.requires_grad
+    ):
+        return True
+    return False
+
+
+def lambda_auto_wrap_policy(
+    module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
+) -> bool:
+    """
+    A convenient auto wrap policy to wrap submodules based on an arbitrary user
+    function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as
+    a `wrapper_cls` unit.
+
+    Return if a module should be wrapped during auto wrapping.
+
+    The first three parameters are required by :func:`_recursive_wrap`.
+
+    Args:
+        module (nn.Module): Current module being considered.
+        recurse (bool): If ``False``, then this function must decide whether
+            ``module`` should be wrapped as an FSDP instance or not. If
+            ``True``, then the function is still recursing down the module
+            tree as a part of the DFS.
+        nonwrapped_numel (int): Parameter numel not yet wrapped.
+
+        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
+            this module will be wrapped.
+    """
+    if recurse:
+        return True  # always recurse
+    return lambda_fn(module)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
+    parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
+    parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
+    parser.add_argument(
+        "--tokenizer_path",
+        default="facebook/opt-1.3b",
+        type=str,
+        help="path to tokenizer",
+    )
+    parser.add_argument(
+        "--run_name",
+        type=str,
+        default="openflamingo3B",
+        help="used to name saving directory and wandb run",
+    )
+    parser.add_argument("--use_media_placement_augmentation", action="store_true")
+    parser.add_argument("--offline", action="store_true")
+    parser.add_argument("--num_steps", type=int, default=300000)
+    parser.add_argument(
+        "--logging_steps", type=int, default=10, help="log loss every n steps"
+    )
+    # Sum of gradient optimization batch size
+    parser.add_argument("--batch_size_mmc4", type=int, default=128)
+    parser.add_argument("--batch_size_laion", type=int, default=128)
+    parser.add_argument("--batch_size_pile", type=int, default=128)
+    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states",
+        default=None,
+    )
+    parser.add_argument(
+        "--delete_previous_checkpoint",
+        action="store_true",
+        help="delete previous checkpoint when saving new checkpoint",
+    )
+    parser.add_argument(
+        "--laion_shards",
+        type=str,
+        help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
+    )
+    parser.add_argument(
+        "--mmc4_shards",
+        type=str,
+        help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
+    )
+    parser.add_argument(
+        "--pile_shards",
+        type=str,
+        default=None,
+        help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
+    )
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--learning_rate", default=1e-4, type=float)
+    parser.add_argument(
+        "--lr_scheduler",
+        default="constant",
+        type=str,
+        help="constant, linear, or cosine",
+    )
+    parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0)
+    parser.add_argument("--loss_multiplier_laion", type=float, default=1.0)
+    parser.add_argument("--loss_multiplier_pile", type=float, default=1.0)
+    parser.add_argument("--loss_multiplier_det", type=float, default=1.0)
+    parser.add_argument("--loss_multiplier_rel", type=float, default=1.0)
+    parser.add_argument("--loss_multiplier_attn", type=float, default=1.0)
+    parser.add_argument("--warmup_steps", default=5000, type=int)
+    # weight decay is only apply to YOLOX head if using FSDP
+    # https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159
+    parser.add_argument("--weight_decay", default=0.05, type=float)
+    parser.add_argument(
+        "--precision",
+        choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
+        default="fp32",
+        help="Floating point precision.",
+    )
+    # data args
+    parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--dataset_resampled", action="store_true")
+    # distributed training args
+    parser.add_argument(
+        "--dist-url",
+        default="env://",
+        type=str,
+        help="url used to set up distributed training",
+    )
+    parser.add_argument(
+        "--dist-backend", default="nccl", type=str, help="distributed backend"
+    )
+    parser.add_argument(
+        "--horovod",
+        default=False,
+        action="store_true",
+        help="Use horovod for distributed training.",
+    )
+    parser.add_argument(
+        "--no-set-device-rank",
+        default=False,
+        action="store_true",
+        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
+    )
+    # wandb args
+    parser.add_argument("--report_to_wandb", default=False, action="store_true")
+    parser.add_argument(
+        "--wandb_project",
+        type=str,
+    )
+    parser.add_argument(
+        "--wandb_entity",
+        type=str,
+    )
+    parser.add_argument(
+        "--save_checkpoints_to_wandb",
+        default=False,
+        action="store_true",
+        help="save checkpoints to wandb",
+    )
+    parser.add_argument(
+        "--checkpoint_activations",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--freeze_vision_encoder",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--mmc4_textsim_threshold",
+        default=30,
+        type=float,
+        help="threshold for filtering images in mmc4 based on image-text similarity",
+    )
+    parser.add_argument(
+        "--location_token_num",
+        default=1000,
+        type=int,
+    )
+    parser.add_argument(
+        "--vis_embed_size",
+        type=int,
+        required=False,
+    )
+    parser.add_argument(
+        "--save_interval",
+        default=1000,
+        type=int,
+        required=False,
+    )
+    parser.add_argument(
+        "--skip_delete_pattern",
+        default=1500,
+        type=int,
+        required=False,
+    )
+    parser.add_argument(
+        "--ddp",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--pile_freq",
+        default=1,
+        type=int,
+        required=False,
+    )
+    parser.add_argument(
+        "--restart",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--lora",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--lora_r",
+        default=16,
+        type=int,
+        required=False,
+    )
+    parser.add_argument(
+        "--single",
+        default=False,
+        action="store_true",
+    )
+
+    # Finetune
+    parser.add_argument(
+        "--instruct",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--fix-ffn",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--prob_ground",
+        default=1.0,
+        type=float,
+        required=False,
+    )
+    parser.add_argument(
+        "--optimizer",
+        default="adamw",
+        type=str,
+        required=False,
+    )
+    parser.add_argument(
+        "--add_visual_token",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--use_format_v2",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--use_sam",
+        default=None,
+        type=str,
+        required=False,
+    )
+    parser.add_argument(
+        "--max-length",
+        default=608,
+        type=int,
+        required=False,
+    )
+    parser.add_argument(
+        "--image-size",
+        default=256,
+        type=int,
+        required=False,
+    )
+    parser.add_argument(
+        "--reset_llm",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--add_box",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--add_pe",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--only_grounded_sample",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--expand",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--delete_contained",
+        default=False,
+        action="store_true",
+    )
+
+    parser.add_argument(
+        "--relation",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--attn_reg",
+        default="l1",
+        type=str,
+        required=False,
+    )
+    parser.add_argument(
+        "--enhance_data",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--no_visual",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--no_previsual",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--roi_align",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--roi_output_size",
+        default=4,
+        type=int,
+        required=False,
+    )
+    parser.add_argument(
+        "--apply_mask",
+        default=False,
+        action="store_true",
+    )
+    parser.add_argument(
+        "--longer_previsual",
+        default=False,
+        action="store_true",
+    )
+
+    args = parser.parse_args()
+    assert not args.use_media_placement_augmentation, "Do not enable use_media_placement_augmentation"
+    if args.no_previsual:
+        assert args.no_visual, "no_previsual MUST come with no_visual"
+    assert not args.enhance_data, "dont enable enhance_data"
+
+    if args.offline:
+        os.environ["WANDB_MODE"] = "offline"
+        os.environ["TRANSFORMERS_OFFLINE"] = "1"
+
+    args.local_rank, args.rank, args.world_size = world_info_from_env()
+    print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
+    device_id = init_distributed_device(args)
+
+    random_seed(args.seed)
+    model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
+        args.vision_encoder_path,
+        args.vision_encoder_pretrained,
+        args.lm_path,
+        args.tokenizer_path if args.tokenizer_path else args.lm_path,
+        use_local_files=args.offline,
+        use_media_placement_augmentation=args.use_media_placement_augmentation,
+        checkpoint_activations=args.checkpoint_activations,
+        freeze_vision_encoder=args.freeze_vision_encoder,
+        location_token_num=args.location_token_num,
+        lora=args.lora,
+        lora_r=args.lora_r,
+        fix_ffn=args.fix_ffn,
+        add_visual_token=args.add_visual_token,
+        add_box=args.add_box,
+        add_pe=args.add_pe,
+        add_relation=args.relation,
+        use_format_v2=args.use_format_v2,
+        use_sam=args.use_sam,
+        enhance_data=args.enhance_data,
+        roi_align=args.roi_align,
+        roi_output_size=args.roi_output_size,
+        apply_mask=args.apply_mask,
+    )
+    if args.reset_llm:
+        llm_state_dict = model.lang_encoder.state_dict()
+    if args.rank == 0:
+        print(args)
+        print(image_processor)
+
+    random_seed(args.seed, args.rank)
+
+    if args.rank == 0 and args.report_to_wandb:
+        wandb.init(
+            project=args.wandb_project,
+            entity=args.wandb_entity,
+            name=args.run_name,
+            config=vars(args),
+        )
+
+    device_id = args.rank % torch.cuda.device_count()
+    if args.ddp:
+        print("use ddp mode")
+        model = model.to(device_id)
+        model = DDP(model)
+    else:
+        fpSixteen = MixedPrecision(
+            param_dtype=torch.float16,
+            # Gradient communication precision.
+            reduce_dtype=torch.float16,
+            # Buffer precision.
+            # buffer_dtype=torch.float16,
+        )
+        # from transformers.models.opt.modeling_opt import OPTDecoderLayer
+        from open_clip.transformer import ResidualAttentionBlock
+        from open_flamingo.src.flamingo_lm import FlamingoLayer
+        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention
+        from segment_anything.modeling.image_encoder import Block
+        transformer_layer_cls=[
+            FlamingoLayer,
+            ResidualAttentionBlock,
+            Block,
+        ]
+        if args.fix_ffn:
+            transformer_layer_cls.append(OPTAttention)
+        auto_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls=transformer_layer_cls,
+        )
+        if args.lora:
+            from torch.distributed.fsdp.wrap import _or_policy
+            lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
+            auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
+            ignored_modules = [model.vision_encoder]
+            # ignored_modules = None
+        else:
+            ignored_modules = [model.detection_head]
+            # ignored_modules = None
+        if args.add_pe:
+            ignored_modules += [model.pos_enc]
+        # if args.use_format_v2:
+        #     ignored_modules += [model.lang_encoder.visual_guided_lm_head]
+        model = FSDP(
+            model,
+            auto_wrap_policy=auto_wrap_policy,
+            mixed_precision=fpSixteen,
+            device_id=torch.cuda.current_device(),
+            ignored_modules=ignored_modules,
+            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
+        )
+        model = model.to(device_id)
+
+
+    pile_dataset = None
+    if args.instruct:
+        laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
+    else:
+        laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
+    if args.pile_shards is not None:
+        pile_dataset = get_data(args, image_processor, tokenizer, "pile")
+
+
+    optim_groups = get_grouped_params(model, args)
+    # optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
+    if args.ddp:
+        optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
+        # optimizer = ZeroRedundancyOptimizer(
+        #     optim_groups,
+        #     optimizer_class=torch.optim.AdamW,
+        #     lr=args.learning_rate,
+        #     parameters_as_bucket_view=True,
+        # )
+    else:
+        if args.optimizer == "adamw":
+            print("use adamw")
+            optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
+        elif args.optimizer == "sgd":
+            print("use sgd...")
+            optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
+        else:
+            raise NotImplementedError
+
+    total_training_steps = args.num_steps
+
+    if args.rank == 0:
+        logging.info(f"Total training steps: {total_training_steps}")
+
+    if args.lr_scheduler == "linear":
+        lr_scheduler = get_linear_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=args.warmup_steps,
+            num_training_steps=total_training_steps,
+        )
+    elif args.lr_scheduler == "cosine":
+        lr_scheduler = get_cosine_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=args.warmup_steps,
+            num_training_steps=total_training_steps,
+        )
+    else:
+        lr_scheduler = get_constant_schedule_with_warmup(
+            optimizer, num_warmup_steps=args.warmup_steps
+        )
+    if args.ddp:
+        scaler = GradScaler()
+    else:
+        scaler = ShardedGradScaler()
+    total_laion_token = 0
+    total_pile_token = 0
+    total_laion_sample = 0
+    total_step = 0
+
+    # check if a checkpoint exists for this run
+    if os.path.exists(f"{args.run_name}"):
+        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
+        if len(checkpoint_list) == 0:
+            if args.rank == 0:
+                logging.info(f"Found no checkpoints for run {args.run_name}.")
+        else:
+            args.resume_from_checkpoint = sorted(
+                checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
+            )[-1]
+            if args.rank == 0:
+                logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
+            args.restart = False
+            if args.rank == 0:
+                logging.info("do not restart because an existed checkpoint is found")
+    if args.resume_from_checkpoint is not None:
+        if args.rank == 0:
+            logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}")
+        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
+        torch.distributed.barrier()
+        if args.ddp:
+            model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
+            # sharded_osd = checkpoint['optimizer_state_dict']
+        else:
+            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
+                if args.reset_llm:
+                    for key in checkpoint["model_state_dict"]:
+                        if key.startswith("lang_encoder"):
+                            if args.rank == 0:
+                                logging.info(f"reset {key}")
+                            llm_key = key.replace("lang_encoder.", "")
+                            checkpoint["model_state_dict"][key] = llm_state_dict[llm_key]
+                model_state_dict = model.state_dict()
+                for key in checkpoint["model_state_dict"].keys():
+                    if model_state_dict[key].shape != checkpoint["model_state_dict"][key].shape:
+                        if args.rank == 0:
+                            logging.info(f'{key}: shape mismatched! {model_state_dict[key].shape} vs {checkpoint["model_state_dict"][key].shape}')
+                        checkpoint["model_state_dict"][key] = model_state_dict[key].clone()
+                del model_state_dict
+                model.load_state_dict(checkpoint["model_state_dict"], False)
+            # sharded_osd = FSDP.shard_full_optim_state_dict(checkpoint['optimizer_state_dict'], model, optim_input=optim_groups)
+        if not args.restart:
+            # optimizer.load_state_dict(sharded_osd)
+            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
+            # scaler.load_state_dict(checkpoint["scaler_state_dict"])
+            total_laion_token = checkpoint.get("total_laion_token", 0)
+            total_pile_token = checkpoint.get("total_pile_token", 0)
+            total_laion_sample = checkpoint.get("total_laion_sample", 0)
+            total_step = checkpoint.get("total_step", 0)
+            if args.rank == 0:
+                logging.info("load training statistics...")
+        else:
+            if args.rank == 0:
+                logging.info("restart training / finetuning. only load model weight...")
+        del checkpoint
+        if args.reset_llm:
+            del llm_state_dict
+        torch.cuda.empty_cache()
+        torch.distributed.barrier()
+
+    model.train()
+    if args.rank == 0:
+        if not os.path.exists(args.run_name):
+            os.makedirs(args.run_name)
+        writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
+    else:
+        writer = None
+
+    laion_dataset.set_epoch(total_step)
+    laion_loader = laion_dataset.dataloader
+    if pile_dataset is not None:
+        pile_dataset.set_epoch(total_step)
+        pile_loader = pile_dataset.dataloader
+    else:
+        pile_loader = FakeDataloader()
+    train_one_epoch(
+        args=args,
+        model=model,
+        tokenizer=tokenizer,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        laion_loader=laion_loader,
+        pile_loader=pile_loader,
+        device_id=device_id,
+        writer=writer,
+        scaler=scaler,
+        optim_groups=optim_groups,
+        total_laion_token=total_laion_token,
+        total_pile_token=total_pile_token,
+        total_laion_sample=total_laion_sample,
+        total_step=total_step,
+    )
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/open_flamingo/train/train_utils.py b/multimodal/open_flamingo/train/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ccdaf049ba5092933a6c01bc28019cb55174b30
--- /dev/null
+++ b/multimodal/open_flamingo/train/train_utils.py
@@ -0,0 +1,387 @@
+import time
+from contextlib import suppress
+import numpy as np
+
+import torch
+from tqdm import tqdm
+import datetime
+import os
+import gc
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    MixedPrecision,
+    BackwardPrefetch,
+    ShardingStrategy,
+    FullStateDictConfig,
+    StateDictType,
+)
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp.wrap import (
+    transformer_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
+
+from torch.utils.tensorboard import SummaryWriter
+import logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s %(message)s',
+    datefmt='%m/%d %I:%M:%S',
+)
+
+def get_cast_dtype(precision: str):
+    cast_dtype = None
+    if precision == "bf16":
+        cast_dtype = torch.bfloat16
+    elif precision == "fp16":
+        cast_dtype = torch.float16
+    return cast_dtype
+
+
+def get_autocast(precision):
+    if precision == "amp_fp16":
+        return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
+    elif precision == "amp_bfloat16" or precision == "amp_bf16":
+        # amp_bfloat16 is more stable than amp float16 for clip training
+        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
+    else:
+        return suppress
+
+
+def get_sync(model, flag):
+    if flag:
+        return suppress
+    else:
+        return lambda: model.no_sync()
+
+
+def train_one_epoch(
+    args,
+    model,
+    laion_loader,
+    pile_loader,
+    tokenizer,
+    optimizer,
+    lr_scheduler,
+    device_id,
+    writer: SummaryWriter,
+    optim_groups,
+    scaler,
+    total_laion_token: int,
+    total_pile_token: int,
+    total_laion_sample: int,
+    total_step: int,
+):
+    world_size = torch.distributed.get_world_size()
+    autocast = get_autocast(args.precision)
+    cast_dtype = get_cast_dtype(args.precision)
+
+    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
+    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
+    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
+    if args.add_box:
+        box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
+        endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
+        endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
+    if args.use_format_v2:
+        prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
+        previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
+    if args.rank == 0:
+        logging.info(f"train from: {total_step} step")
+    model.train()
+    # loop through dataloader
+    last_logging_step = total_step
+    last_save_step = total_step
+    for num_steps, (batch_laion, batch_pile) in tqdm(
+        enumerate(zip(laion_loader, pile_loader)),
+        disable=args.rank != 0 or "SLURM_PROCID" in os.environ,
+        total=args.num_steps * args.gradient_accumulation_steps,
+        initial=total_step * args.gradient_accumulation_steps,
+    ):
+        #### LAION FORWARD PASS ####
+        images = (
+            batch_laion[0]
+            .to(device_id, dtype=cast_dtype, non_blocking=True)
+            .unsqueeze(1)
+            .unsqueeze(1)
+        )
+        image_nums = batch_laion[1]
+        image_start_index_list = batch_laion[2]
+
+        # TODO: OPT model: input_ids is not started with </s> while input_ids2 is?
+        input_ids = batch_laion[3].to(device_id, non_blocking=True).long()
+        attention_mask = batch_laion[4].to(device_id, dtype=cast_dtype, non_blocking=True)
+        added_bbox_list = [x.to(device_id) for x in batch_laion[5]] # list object
+        total_laion_token += int(attention_mask.sum().long()) * world_size
+        total_laion_sample += sum(image_nums) * world_size
+
+        labels = input_ids.clone()
+        if args.add_box:
+            labels[input_ids == visual_token_id] = -100
+            labels[input_ids == box_token_id] = -100
+            labels[input_ids == endofattr_token_id] = -100
+            if args.use_format_v2:
+                labels[input_ids == previsual_token_id] = -100
+                labels[input_ids == prebox_token_id] = -100
+                labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
+                labels[torch.roll(input_ids == box_token_id, 1)] = -100
+        labels[:, 0] = -100
+        labels[input_ids == tokenizer.pad_token_id] = -100
+        labels[input_ids == media_token_id] = -100
+        labels[input_ids == endofmedia_token_id] = -100
+        labels.to(device_id)
+        current_laion_num = input_ids.shape[0]
+
+        #### PILE FORWARD PASS ####
+        if batch_pile is not None and batch_pile[0] is not None and batch_pile[1] is not None:
+            input_ids2 = batch_pile[0].to(device_id, non_blocking=True).long()
+            attention_mask2 = batch_pile[1].to(device_id, dtype=cast_dtype, non_blocking=True)
+            input_length = input_ids.shape[-1]
+
+            input_ids2 = torch.cat([input_ids2, torch.ones((input_ids2.shape[0], input_length - input_ids2.shape[1]), device=input_ids2.device, dtype=input_ids2.dtype) * tokenizer.pad_token_id], dim=-1)
+            attention_mask2 = torch.cat([attention_mask2, torch.zeros((attention_mask2.shape[0], input_length - attention_mask2.shape[1]), device=attention_mask2.device, dtype=attention_mask2.dtype)], dim=-1)
+
+            labels2 = input_ids2.clone()
+            labels2[labels2 == tokenizer.pad_token_id] = -100
+            labels2[:, 0] = -100
+            labels2.to(device_id)
+
+            if (num_steps != 0 and num_steps % args.pile_freq == 0) or args.pile_freq == 1:
+                image_nums = image_nums + [0] * len(input_ids2)
+                image_start_index_list = image_start_index_list + [[]] * len(input_ids2)
+                input_ids = torch.cat([input_ids, input_ids2], dim=0)
+                attention_mask = torch.cat([attention_mask, attention_mask2], dim=0)
+                labels = torch.cat([labels, labels2], dim=0)
+                total_pile_token += int(attention_mask2.sum().long()) * world_size
+            else:
+                del input_ids2
+                del attention_mask2
+                del labels2
+
+        if args.instruct:
+            answer_token_id = tokenizer(" Answer").input_ids[0]
+            answer_token_loc = (input_ids == answer_token_id).nonzero()
+            for batch_idx, idx in answer_token_loc:
+                labels[batch_idx][:idx+2] = -100
+        
+        if args.relation and not args.instruct:
+            relations = batch_laion[6]
+        else:
+            relations = None
+        if len(added_bbox_list) == 0:
+            added_bbox_list = None
+        update_flag = (num_steps != 0 and num_steps % args.gradient_accumulation_steps == 0) or args.gradient_accumulation_steps == 1
+        # do_sync = get_sync(model, update_flag)
+        with autocast():
+            # modify: 
+            #   /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py
+            #   /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py
+            # CrossEntropyLoss(reduction="none")
+            outputs = model(
+                vision_x=images,
+                lang_x=input_ids,
+                attention_mask=attention_mask,
+                labels=labels,
+                image_nums=image_nums,
+                image_start_index_list=image_start_index_list,
+                added_bbox_list=added_bbox_list,
+                add_box=args.add_box,
+                relations=relations,
+            )
+            loss_total = outputs.loss.reshape(labels.shape[0], -1)
+            loss_sample = loss_total.sum(-1) / (loss_total != 0).sum(-1)
+            loss_sample_for_laion = loss_sample[:current_laion_num]
+            nan_mask = torch.isnan(loss_sample_for_laion)
+            if nan_mask.sum() > 0:
+                logging.warning(f"caption NaN: {nan_mask}")
+            if nan_mask.sum() == len(loss_sample_for_laion) or not model.valid:
+                logging.info("WARNING: skip this caption loss due to some error")
+                loss_laion = torch.tensor(0.0).cuda()
+            else:
+                loss_laion = loss_sample_for_laion[~nan_mask].mean()
+            loss_caption = loss_laion
+            divided_loss_laion = loss_laion / args.gradient_accumulation_steps
+            if current_laion_num != loss_sample.shape[0]:
+                loss_pile = loss_sample[current_laion_num:].mean()
+            else:
+                loss_pile = torch.tensor(0.0).cuda()
+            divided_loss_pile = loss_pile / args.gradient_accumulation_steps
+
+            if "detection_losses" in outputs:
+                loss_det = outputs["detection_losses"]["loss"]
+                loss_iou = outputs["detection_losses"]["loss_iou"]
+                loss_obj = outputs["detection_losses"]["loss_obj"]
+                loss_cls = outputs["detection_losses"]["loss_cls"]
+            else:
+                loss_det = torch.tensor(0.0).cuda()
+                loss_iou = torch.tensor(0.0).cuda()
+                loss_obj = torch.tensor(0.0).cuda()
+                loss_cls = torch.tensor(0.0).cuda()
+
+            if "loss_dict" in outputs:
+                visual_loss_iou = outputs["loss_dict"][0]["loss_iou"]
+                previsual_loss_iou = outputs["loss_dict"][1]["loss_iou"]
+                visual_loss_obj = outputs["loss_dict"][0]["loss_obj"]
+                previsual_loss_obj = outputs["loss_dict"][1]["loss_obj"]
+            else:
+                visual_loss_iou = torch.tensor(0.0).cuda()
+                previsual_loss_iou = torch.tensor(0.0).cuda()
+                visual_loss_obj = torch.tensor(0.0).cuda()
+                previsual_loss_obj = torch.tensor(0.0).cuda()
+
+            divided_loss_det = loss_det / args.gradient_accumulation_steps
+            loss_rel = outputs.get("rel_loss", torch.tensor(0.0).cuda())
+            divided_loss_rel = loss_rel / args.gradient_accumulation_steps
+            loss = (
+                divided_loss_laion * args.loss_multiplier_laion +
+                divided_loss_pile * args.loss_multiplier_pile +
+                divided_loss_det * args.loss_multiplier_det +
+                divided_loss_rel * args.loss_multiplier_rel
+            )
+
+        scaler.scale(loss).backward()
+
+        # for logging only
+        loss = (
+            loss_laion * args.loss_multiplier_laion
+            + loss_pile * args.loss_multiplier_pile
+            + loss_det * args.loss_multiplier_det
+            + loss_rel * args.loss_multiplier_rel
+        ).detach()
+
+        # step optimizer and log
+        if update_flag:
+            #### MASK GRADIENTS FOR EMBEDDINGS ####
+            # Note (anas): Do not apply weight decay to embeddings as it will break this function.
+            # ! not an important point
+            # if args.ddp:
+            #     def mask_embedding(m):
+            #         if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
+            #             zero_mask = torch.zeros_like(m.weight.grad)
+            #             zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
+            #             zero_mask[endofmedia_token_id] = torch.ones_like(zero_mask[endofmedia_token_id])
+            #             m.weight.grad = m.weight.grad * zero_mask
+            #     model.apply(mask_embedding)
+            total_step += 1
+            scaler.unscale_(optimizer)
+            if args.ddp:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            else:
+                model.clip_grad_norm_(1.0)
+            scaler.step(optimizer)
+            scaler.update()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+            # https://github.com/facebookresearch/fairscale/issues/627
+            model.zero_grad(set_to_none=True)
+
+        if args.rank == 0 and total_step % args.logging_steps == 0 and total_step != last_logging_step:
+            last_logging_step = total_step
+            global_step = total_step
+            lr = optimizer.param_groups[0]["lr"]
+            writer.add_scalar("lr", lr, global_step)
+            writer.add_scalar("scale", scaler.get_scale(), global_step)
+            writer.add_scalar("loss_groundcaption", loss_laion.item(), global_step)
+            writer.add_scalar("loss_laion", loss_caption.item(), global_step)
+            writer.add_scalar("loss_pile", loss_pile.item(), global_step)
+            writer.add_scalar("loss", loss.item(), global_step)
+            writer.add_scalar("loss_det", loss_det.item(), global_step)
+            writer.add_scalar("loss_iou", loss_iou.item(), global_step)
+            writer.add_scalar("loss_obj", loss_obj.item(), global_step)
+            writer.add_scalar("loss_cls", loss_cls.item(), global_step)
+            if loss_rel.item() != 0:
+                writer.add_scalar("loss_rel", loss_rel.item(), global_step)
+            if args.use_format_v2:
+                writer.add_scalar("loss_iou_visual", visual_loss_iou.item(), global_step)
+                writer.add_scalar("loss_obj_visual", visual_loss_obj.item(), global_step)
+                writer.add_scalar("loss_iou_previsual", previsual_loss_iou.item(), global_step)
+                writer.add_scalar("loss_obj_previsual", previsual_loss_obj.item(), global_step)
+
+            global_sample_num = total_laion_sample
+            writer.add_scalar("loss_groundcaption_vs_sample_num", loss_laion.item(), global_sample_num)
+            writer.add_scalar("loss_laion_vs_sample_num", loss_caption.item(), global_sample_num)
+            writer.add_scalar("loss_pile_vs_sample_num", loss_pile.item(), global_sample_num)
+            writer.add_scalar("loss_vs_sample_num", loss.item(), global_sample_num)
+            writer.add_scalar("loss_det_vs_sample_num", loss_det.item(), global_sample_num)
+            writer.add_scalar("loss_iou_vs_sample_num", loss_iou.item(), global_sample_num)
+            writer.add_scalar("loss_obj_vs_sample_num", loss_obj.item(), global_sample_num)
+            if loss_rel.item() != 0:
+                writer.add_scalar("loss_rel_vs_sample_num", loss_rel.item(), global_sample_num)
+            writer.add_scalar("lr_vs_sample_num", optimizer.param_groups[0]["lr"], global_sample_num)
+
+            writer.add_scalar("loss_groundcaption_vs_token", loss_laion.item(), total_laion_token)
+            writer.add_scalar("loss_laion_vs_token", loss_caption.item(), total_laion_token)
+            writer.add_scalar("loss_pile_vs_token", loss_pile.item(), total_pile_token)
+            writer.add_scalar("loss_det_vs_token", loss_det.item(), total_laion_token)
+            writer.add_scalar("loss_iou_vs_token", loss_iou.item(), total_laion_token)
+            writer.add_scalar("loss_obj_vs_token", loss_obj.item(), total_laion_token)
+            writer.add_scalar("loss_cls_vs_token", loss_cls.item(), total_laion_token)
+            if loss_rel.item() != 0:
+                writer.add_scalar("loss_rel_vs_token", loss_rel.item(), total_laion_token)
+
+            total_token = total_laion_token + total_pile_token
+            writer.add_scalar("sample_num", global_sample_num, global_step)
+            writer.add_scalar("total_laion_token", total_laion_token, global_step)
+            writer.add_scalar("total_pile_token", total_pile_token, global_step)
+            writer.add_scalar("total_token", total_token, global_step)
+            logging.info(
+                f"[{global_step}][{total_laion_sample}][{total_token}]. total: {loss.item():.3f} //  laion: {loss_caption.item():.3f} // pile: {loss_pile.item():.3f} // iou: {loss_iou.item():.4f} // obj: {loss_obj.item():.4f} // previsual_obj: {previsual_loss_obj.item():.4f} // visual_obj: {visual_loss_obj.item():.4f} // previsual_iou: {previsual_loss_iou.item():.4f} // visual_iou: {visual_loss_iou.item():.4f} // lr: {lr:.2e} // scale: {scaler.get_scale()}"
+            )
+
+        if total_step % args.save_interval == 0 and total_step != last_save_step:
+            last_save_step = total_step
+            torch.distributed.barrier()
+            if args.ddp:
+                cpu_state = model.state_dict()
+                # if args.rank == 0:
+                #     optimizer_state = optimizer.state_dict()
+            else:
+                save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+                with FSDP.state_dict_type(
+                    model, StateDictType.FULL_STATE_DICT, save_policy
+                ):
+                    cpu_state = model.state_dict()
+                torch.distributed.barrier()
+                # https://pytorch.org/docs/1.12/fsdp.html
+                # need to pass optim_groups as optim_input
+                # optimizer_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_groups)
+            if args.rank == 0:
+                checkpoint_dict = {
+                    "model_state_dict": cpu_state,
+                    # "optimizer_state_dict": optimizer_state,
+                    "lr_scheduler_state_dict": lr_scheduler.state_dict(),
+                    "scaler_state_dict": scaler.state_dict(),
+                    "total_pile_token": total_pile_token,
+                    "total_laion_token": total_laion_token,
+                    "total_laion_sample": total_laion_sample,
+                    "total_step": total_step,
+                }
+                logging.info(f"Saving checkpoint to {args.run_name}/checkpoint_{total_step}.pt")
+                torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{total_step}.pt")
+                del checkpoint_dict
+                if args.delete_previous_checkpoint and total_step-args.save_interval > 0 and (total_step-args.save_interval) % args.skip_delete_pattern != 0:
+                    try:
+                        os.remove(f"{args.run_name}/checkpoint_{total_step-args.save_interval}.pt")
+                    except:
+                        pass
+            torch.distributed.barrier()
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
diff --git a/multimodal/playground.py b/multimodal/playground.py
new file mode 100644
index 0000000000000000000000000000000000000000..5601eda90d9759f56b6cefdb7b91129634b6cad4
--- /dev/null
+++ b/multimodal/playground.py
@@ -0,0 +1,11 @@
+import os
+import json
+
+if __name__ == "__main__":
+    blip2_cases = os.listdir("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal2/blip2_baseline/blip2_fail_case")
+    kmos2_cases = os.listdir("/gpfs/u/home/LMCG/LMCGljnn/scratch/code/unilm/kosmos-2/kmos2_fail_case")
+    blip2_failed_ids = set([int(c.split("_")[0]) for c in blip2_cases])
+    kmos2_failed_ids = set([int(c.split("_")[0]) for c in kmos2_cases])
+    both_failed_ids = list(blip2_failed_ids.intersection(kmos2_failed_ids))
+    print(both_failed_ids)
+    json.dump(both_failed_ids, open("both_failed_ids.json", "w"), indent=1)
diff --git a/multimodal/plot_log.py b/multimodal/plot_log.py
new file mode 100644
index 0000000000000000000000000000000000000000..1da8334d720112c0f87f5b916a2a533b11ce40cd
--- /dev/null
+++ b/multimodal/plot_log.py
@@ -0,0 +1,19 @@
+import matplotlib.pyplot as plt
+
+visual = []
+previsual = []
+with open("slurm-878067.out") as f:
+    lines = f.readlines()
+    for line in lines:
+        if "previsual_iou" in line:
+            line = line.strip().split("//")
+            pre_iou = float(line[-6].strip().split(" ")[-1])
+            iou = float(line[-5].strip().split(" ")[-1])
+            visual.append(iou)
+            previsual.append(pre_iou)
+
+plt.plot(visual, label="visual")
+plt.plot(previsual, label="previsual")
+plt.legend()
+plt.ylim(0, 4)
+plt.savefig("save.png")
diff --git a/multimodal/range_aro.sh b/multimodal/range_aro.sh
new file mode 100644
index 0000000000000000000000000000000000000000..30aedf5a0d3caab83c9911635037d14afee96d21
--- /dev/null
+++ b/multimodal/range_aro.sh
@@ -0,0 +1,20 @@
+sbatch -J aro2 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_2000.pt
+sbatch -J aro4 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_4000.pt
+sbatch -J aro6 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_6000.pt
+sbatch -J aro8 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_8000.pt
+sbatch -J aro10 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_10000.pt
+sbatch -J aro11 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_11000.pt
+sbatch -J aro12 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_12000.pt
+sbatch -J aro13 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_13000.pt
+sbatch -J aro14 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_14000.pt
+sbatch -J aro15 submit_eval.sh eval_aro.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_15000.pt
+
+
+sbatch -J aro3B2 submit_eval.sh eval_aro_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_2000.pt
+sbatch -J aro3B4 submit_eval.sh eval_aro_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_4000.pt
+sbatch -J aro3B6 submit_eval.sh eval_aro_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_6000.pt
+sbatch -J aro3B8 submit_eval.sh eval_aro_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_8000.pt
+sbatch -J aro3B10 submit_eval.sh eval_aro_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_10000.pt
+sbatch -J aro3B12 submit_eval.sh eval_aro_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_12000.pt
+sbatch -J aro3B14 submit_eval.sh eval_aro_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_14000.pt
+sbatch -J aro3B16 submit_eval.sh eval_aro_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_16000.pt
diff --git a/multimodal/range_vlc.sh b/multimodal/range_vlc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d3c4138eb5fe1e3b89ef65b04e81860d5d0994db
--- /dev/null
+++ b/multimodal/range_vlc.sh
@@ -0,0 +1,20 @@
+sbatch -J vlc3B submit_eval.sh eval_vlc_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_2000.pt
+sbatch -J vlc3B submit_eval.sh eval_vlc_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_4000.pt
+sbatch -J vlc3B submit_eval.sh eval_vlc_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_6000.pt
+sbatch -J vlc3B submit_eval.sh eval_vlc_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_8000.pt
+sbatch -J vlc3B submit_eval.sh eval_vlc_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_10000.pt
+sbatch -J vlc3B submit_eval.sh eval_vlc_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_12000.pt
+sbatch -J vlc3B submit_eval.sh eval_vlc_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_14000.pt
+sbatch -J vlc3B submit_eval.sh eval_vlc_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_16000.pt
+
+
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_2000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_4000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_6000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_8000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_10000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_12000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_14000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_16000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_18000.pt
+sbatch -J vlc submit_eval.sh eval_vlc.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_20000.pt
diff --git a/multimodal/range_vqa.sh b/multimodal/range_vqa.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1062258553b419998ec2aad5f73d3ba45ce13a72
--- /dev/null
+++ b/multimodal/range_vqa.sh
@@ -0,0 +1,9 @@
+sbatch -J vqa submit_eval.sh eval_vqav2.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_12000.pt
+sbatch -J vqa submit_eval.sh eval_vqav2.sh checkpoints/091701_pythiaS_previsual_fix/checkpoint_18000.pt
+
+
+sbatch -J vqa3B submit_eval.sh eval_vqav2_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_10000.pt
+sbatch -J vqa3B submit_eval.sh eval_vqav2_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_12000.pt
+sbatch -J vqa3B submit_eval.sh eval_vqav2_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_14000.pt
+sbatch -J vqa3B submit_eval.sh eval_vqav2_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_16000.pt
+sbatch -J vqa3B submit_eval.sh eval_vqav2_3b.sh checkpoints/091801_pythia3b_previsual_fix/checkpoint_18000.pt
diff --git a/multimodal/requirements-dev.txt b/multimodal/requirements-dev.txt
new file mode 100644
index 0000000000000000000000000000000000000000..429f646f468bad4cc842cff203810b985d55d41d
--- /dev/null
+++ b/multimodal/requirements-dev.txt
@@ -0,0 +1,5 @@
+black
+mypy
+pylint
+pytest
+requests
\ No newline at end of file
diff --git a/multimodal/requirements.txt b/multimodal/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b00b128fec6d006ca6c955b330a0821dc01a06ef
--- /dev/null
+++ b/multimodal/requirements.txt
@@ -0,0 +1,16 @@
+einops
+einops-exts
+transformers
+torch
+torchvision
+pillow
+more-itertools
+datasets
+braceexpand
+webdataset
+wandb
+nltk
+scipy
+inflection
+sentencepiece
+open_clip_torch
diff --git a/multimodal/setup.py b/multimodal/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..8510a49f8dd9a4f0379304aa835277a7addfd9b6
--- /dev/null
+++ b/multimodal/setup.py
@@ -0,0 +1,57 @@
+from pathlib import Path
+
+from setuptools import find_packages, setup
+
+if __name__ == "__main__":
+    with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file:
+        long_description = file.read()
+
+    # TODO: This is a hack to get around the fact that we can't read the requirements.txt file, we should fix this.
+    # def _read_reqs(relpath):
+    #     fullpath = os.path.join(Path(__file__).parent, relpath)
+    #     with open(fullpath) as f:
+    #         return [
+    #             s.strip()
+    #             for s in f.readlines()
+    #             if (s.strip() and not s.startswith("#"))
+    #         ]
+
+    REQUIREMENTS = [
+        "einops",
+        "einops-exts",
+        "transformers",
+        "torch",
+        "torchvision",
+        "pillow",
+        "more-itertools",
+        "datasets",
+        "braceexpand",
+        "webdataset",
+        "wandb",
+        "nltk",
+        "scipy",
+        "inflection",
+        "sentencepiece",
+        "open_clip_torch",
+    ]
+
+    setup(
+        name="open_flamingo",
+        packages=find_packages(),
+        include_package_data=True,
+        version="0.0.2",
+        license="MIT",
+        description="An open-source framework for training large multimodal models",
+        long_description=long_description,
+        long_description_content_type="text/markdown",
+        data_files=[(".", ["README.md"])],
+        keywords=["machine learning"],
+        install_requires=REQUIREMENTS,
+        classifiers=[
+            "Development Status :: 4 - Beta",
+            "Intended Audience :: Developers",
+            "Topic :: Scientific/Engineering :: Artificial Intelligence",
+            "License :: OSI Approved :: MIT License",
+            "Programming Language :: Python :: 3.9",
+        ],
+    )
diff --git a/multimodal/tests/test_flamingo_model.py b/multimodal/tests/test_flamingo_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8cb70c42197d0a3ca8eb831ca87def07a43262
--- /dev/null
+++ b/multimodal/tests/test_flamingo_model.py
@@ -0,0 +1,77 @@
+# import unittest
+
+# import requests
+# from PIL import Image
+
+# from open_flamingo import create_model_and_transforms
+
+
+# class TestFlamingoModel(unittest.TestCase):
+#     def test_forward_pass(self):
+#         model, image_processor, tokenizer = create_model_and_transforms(
+#             clip_vision_encoder_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
+#             clip_processor_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
+#             lang_encoder_path="hf-internal-testing/tiny-random-OPTModel",
+#             tokenizer_path="hf-internal-testing/tiny-random-OPTModel",
+#         )
+
+#         image = Image.open(
+#             requests.get(
+#                 "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
+#             ).raw
+#         )
+#         vis_x = image_processor(images=[image, image], return_tensors="pt")[
+#             "pixel_values"
+#         ]
+#         vis_x = vis_x.unsqueeze(1).unsqueeze(1)
+#         lang_x = tokenizer(
+#             ["<|#image#|> A dog", "<|#image#|> A cat"],
+#             max_length=10,
+#             padding=True,
+#             truncation=True,
+#             return_tensors="pt",
+#         )
+
+#         # try batched forward pass
+#         model(vis_x, lang_x["input_ids"], attention_mask=lang_x["attention_mask"])
+
+#     def test_generate(self):
+#         model, image_processor, tokenizer = create_model_and_transforms(
+#             clip_vision_encoder_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
+#             clip_processor_path="hf-internal-testing/tiny-random-clip-zero-shot-image-classification",
+#             lang_encoder_path="hf-internal-testing/tiny-random-OPTModel",
+#             tokenizer_path="hf-internal-testing/tiny-random-OPTModel",
+#         )
+
+#         tokenizer.padding_side = (
+#             "left"  # we want to pad on the left side for generation
+#         )
+
+#         image = Image.open(
+#             requests.get(
+#                 "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
+#             ).raw
+#         )
+#         vis_x = image_processor(images=[image, image], return_tensors="pt")[
+#             "pixel_values"
+#         ]
+#         vis_x = vis_x.unsqueeze(1).unsqueeze(1)
+#         lang_x = tokenizer(
+#             ["<|#image#|> A dog", "<|#image#|> A cat <|endofchunk|>"],
+#             max_length=10,
+#             padding=True,
+#             truncation=True,
+#             return_tensors="pt",
+#         )
+
+#         # try batched generation
+#         model.generate(
+#             vis_x,
+#             lang_x["input_ids"],
+#             attention_mask=lang_x["attention_mask"],
+#             max_new_tokens=20,
+#         )
+
+
+# if __name__ == "__main__":
+#     unittest.main()
diff --git a/multimodal/tools/add_vg_to_blip2_data.py b/multimodal/tools/add_vg_to_blip2_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b1718fd35d61e8360999c83166e7581041a9690
--- /dev/null
+++ b/multimodal/tools/add_vg_to_blip2_data.py
@@ -0,0 +1,25 @@
+import os
+import shutil
+import glob
+import random
+from pprint import pprint
+
+DIR_VG = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_0826"
+DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_ground"
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_ground_with_vg_0826"
+
+
+if __name__ == "__main__":
+    os.makedirs(OUT_DIR, exist_ok=True)
+    blip2_tars = glob.glob(os.path.join(DIR, "*.tar"))
+    vg_tars = glob.glob(os.path.join(DIR_VG, "*", "*.tar"))
+    tars = []
+    tars.extend(blip2_tars)
+    tars.extend(vg_tars)
+    print(len(tars))
+    pprint(tars[:20])
+    pprint(tars[-20:])
+    for i, tar in enumerate(tars):
+        dst = os.path.join(OUT_DIR, f"{str(i).zfill(6)}.tar")
+        # print(tar, dst)
+        os.symlink(tar, dst)
diff --git a/multimodal/tools/check_refcoco.py b/multimodal/tools/check_refcoco.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e8e3bb651dd1adfce8a724f05b0e63866129545
--- /dev/null
+++ b/multimodal/tools/check_refcoco.py
@@ -0,0 +1,14 @@
+import os
+from tqdm import tqdm
+import numpy as np
+import sys
+
+if __name__ == "__main__":
+    captions = []
+    with open(sys.argv[1]) as f:
+        for line in tqdm(f):
+            line = line.rstrip().split("\t")
+            caption = line[2]
+            captions.append(caption)
+    lengths = [len(c.split(" ")) for c in captions]
+    print(np.mean(lengths))
diff --git a/multimodal/tools/convert_mmc4_to_wds.py b/multimodal/tools/convert_mmc4_to_wds.py
new file mode 100644
index 0000000000000000000000000000000000000000..1798e89403b8cf7b5606176449b9e859fd82adbc
--- /dev/null
+++ b/multimodal/tools/convert_mmc4_to_wds.py
@@ -0,0 +1,124 @@
+import argparse
+import base64
+import json
+import os
+import tarfile
+import uuid
+import zipfile
+import time
+
+import braceexpand
+import webdataset as wds
+from tqdm import tqdm
+from tqdm.contrib.concurrent import process_map
+
+arg_parser = argparse.ArgumentParser()
+arg_parser.add_argument("--output_dir", type=str)
+arg_parser.add_argument(
+    "--image_shards",
+    type=str,
+    help="Pass in a list of shards in the format path_to_shard/shard_{0..23098}_images_v2.tar",
+)
+arg_parser.add_argument(
+    "--doc_shards",
+    type=str,
+    help="Pass in a list of shards in the format path_to_shard/docs_shard_{0..23098}_v2.jsonl.zip",
+)
+arg_parser.add_argument(
+    "--thread",
+    type=int,
+    default=128,
+)
+args = arg_parser.parse_args()
+
+def get_txt_to_filename_dict(image_shards, disable_tqdm=False):
+    txt_to_filename_dict = {}
+    dataset = wds.WebDataset(image_shards).decode("pil").to_tuple("txt", "json")
+    for data in tqdm(dataset, disable=disable_tqdm):
+        txt = data[0].split(".")[0]
+        txt_to_filename_dict[txt] = data[1]['key']
+    return txt_to_filename_dict
+
+
+def single_thread(args):
+    i = args["i"]
+    output_dir = args["output_dir"]
+    doc_shards = args["doc_shards"]
+    image_shards = args["image_shards"]
+    if i == 0:
+        tqdm.write(f"output_dir: {output_dir}")
+        tqdm.write(f"doc_shards: {doc_shards[:5]}")
+        tqdm.write(f"image_shards: {image_shards[:5]}")
+    with wds.ShardWriter(os.path.join(output_dir, "%09d.tar"), maxcount=1000) as sink:
+        sink.verbose = False
+        for doc_shard, image_shard in tqdm(zip(doc_shards, image_shards), disable=(i != 0), total=len(doc_shards)):
+            # txt_to_filename_dict = get_txt_to_filename_dict(image_shard, disable_tqdm=(i != 0))
+            # image_tar = tarfile.open(image_shard)
+            # Open the ZIP archive and extract the JSON file
+            with zipfile.ZipFile(doc_shard, "r") as zip_file:
+                # Assumes the JSON file is the first file in the archive
+                json_filename = zip_file.namelist()[0]
+                with zip_file.open(json_filename, "r") as json_file:
+                    pbar = tqdm(json_file, disable=True)
+                    total_num = 0
+                    exist_num = 0
+                    for sample_data in pbar:
+                        # get image names from json
+                        sample_data = json.loads(sample_data)
+                        image_info = sample_data["image_info"]
+                        image_names = [image["image_name"] for image in image_info]
+
+                        # Add each image to the tar file
+                        for img_idx, image_name in enumerate(image_names):
+                            total_num += 1
+                            try:
+                                image = image_tar.extractfile(txt_to_filename_dict[image_name.split(".")[0]]+".jpg")
+                                # convert to base64
+                                image_bytes = image.read()
+                                image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+                                exist_num += 1
+                            except:
+                                tqdm.write(f"{image_name.split('.')[0]}")
+                                image_base64 = "null"
+                            sample_data["image_info"][img_idx][
+                                "image_base64"
+                            ] = image_base64
+
+                        key_str = uuid.uuid4().hex
+                        sink.write({"__key__": key_str, "json": sample_data})
+                        pbar.set_description(f"{exist_num/total_num:.2f}")
+            # image_tar.close()
+
+
+def main():
+    timestamp = int(time.time())
+    os.makedirs(args.output_dir, exist_ok=True)
+    os.makedirs(os.path.join(args.output_dir, str(timestamp)), exist_ok=True)
+    tasks = []
+    for i in range(args.thread):
+        thread_dir = os.path.join(args.output_dir, str(timestamp), str(i))
+        os.makedirs(thread_dir, exist_ok=True)
+        tasks.append({
+            "i": i,
+            "output_dir": thread_dir,
+            "doc_shards": [],
+            "image_shards": [],
+        })
+
+    doc_shards = list(braceexpand.braceexpand(args.doc_shards))
+    image_shards = list(braceexpand.braceexpand(args.image_shards))
+
+    assert len(doc_shards) == len(
+        image_shards
+    ), "Each doc shards must have a corresponding image shard"
+
+    for i, (doc_shard, image_shard) in enumerate(zip(doc_shards, image_shards)):
+        tasks[i % args.thread]["doc_shards"].append(doc_shard)
+        tasks[i % args.thread]["image_shards"].append(image_shard)
+
+    # assert len(tasks) == args.thread
+    # process_map(single_thread, tasks, max_workers=args.thread, disable=True)
+    single_thread(tasks[0])
+
+if __name__ == "__main__":
+    main()
diff --git a/multimodal/tools/instruct_tuning_data/__init__.py b/multimodal/tools/instruct_tuning_data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/multimodal/tools/instruct_tuning_data/merge.py b/multimodal/tools/instruct_tuning_data/merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..682e02b866ba43e00d3a918fb2849dd04a95df8d
--- /dev/null
+++ b/multimodal/tools/instruct_tuning_data/merge.py
@@ -0,0 +1,13 @@
+import os
+import glob
+import random
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/all"
+
+
+if __name__ == "__main__":
+    tars = glob.glob(os.path.join("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct", "*", "*.tar"))
+    random.shuffle(tars)
+    os.makedirs(OUT_DIR, exist_ok=True)
+    for i, tar in enumerate(tars):
+        dst = os.path.join(OUT_DIR, f"{str(i).zfill(6)}.tar")
+        os.symlink(tar, dst)
diff --git a/multimodal/tools/instruct_tuning_data/pisc.py b/multimodal/tools/instruct_tuning_data/pisc.py
new file mode 100644
index 0000000000000000000000000000000000000000..562cec805725481fad3277543f4c8d5947473e6c
--- /dev/null
+++ b/multimodal/tools/instruct_tuning_data/pisc.py
@@ -0,0 +1,51 @@
+import json
+import os
+from tqdm import tqdm
+import webdataset as wds
+from utils import MAXCOUNT, NAMING, check_sample
+import numpy as np
+PISC_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/PISC"
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/eval/pisc"
+
+rel_id_to_type = ["friends", "family", "couple", "professional", "commercial", "no relation"]
+
+if __name__ == "__main__":
+    os.makedirs(OUT_DIR, exist_ok=True)
+    annotation_image_info = json.load(open(os.path.join(PISC_ROOT, "annotation_image_info.json")))
+    relationships = json.load(open(os.path.join(PISC_ROOT, "relationship.json")))
+    relationship_trainidx = json.load(open(os.path.join(PISC_ROOT, "relationship_split", "relation_trainidx.json")))
+    relationship_testidx = json.load(open(os.path.join(PISC_ROOT, "relationship_split", "relation_testidx.json")))
+    data = {}
+    uuid = 0
+    with wds.ShardWriter(os.path.join(OUT_DIR, NAMING), maxcount=MAXCOUNT**3) as sink:
+        for annotation in tqdm(annotation_image_info):
+            imgH = annotation["imgH"]
+            imgW = annotation["imgW"]
+            id = annotation["id"]
+            bbox = annotation["bbox"] # xyxy
+            if str(id) not in relationships:
+                tqdm.write(f"skip {id} due to not in relationships")
+                continue
+            if str(id) not in relationship_testidx:
+                tqdm.write(f"skip {id} due to not in train set")
+                continue
+            relationship = relationships[str(id)]
+            for rel in relationship:
+                type = rel_id_to_type[relationship[rel] - 1]
+                A_id, B_id = list(map(int, rel.split(" ")))
+                A_box = np.array(bbox[A_id - 1]).astype(float) / np.array([imgW, imgH, imgW, imgH]).astype(float)
+                B_box = np.array(bbox[B_id - 1]).astype(float) / np.array([imgW, imgH, imgW, imgH]).astype(float)
+                data = [A_box, B_box, type]
+                image_path = os.path.join(PISC_ROOT, "image", str(id).zfill(5)+".jpg")
+                dataset = "pisc_relation_split"
+                key = f"{dataset}_{id}_{uuid}"
+                uuid += 1
+                assert os.path.exists(image_path)
+                sample = {
+                    "__key__": key,
+                    "image_path.txt": image_path,
+                    "dataset.txt": dataset,
+                    "data.pyd": data,
+                }
+                check_sample(sample)
+                sink.write(sample)
diff --git a/multimodal/tools/instruct_tuning_data/utils.py b/multimodal/tools/instruct_tuning_data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..37237978a6f7089ba4315fa5971b7342e1fcbb7f
--- /dev/null
+++ b/multimodal/tools/instruct_tuning_data/utils.py
@@ -0,0 +1,11 @@
+import os
+
+MAXCOUNT = 1000
+NAMING = "%06d.tar"
+
+def check_sample(sample):
+    assert "__key__" in sample
+    assert "image_path.txt" in sample
+    assert os.path.exists(sample["image_path.txt"])
+    assert "dataset.txt" in sample
+    assert "data.pyd" in sample
diff --git a/multimodal/tools/instruct_tuning_data/vg_relation.py b/multimodal/tools/instruct_tuning_data/vg_relation.py
new file mode 100644
index 0000000000000000000000000000000000000000..661365b7de2fb8acdd7c3e70d15530ace0775e20
--- /dev/null
+++ b/multimodal/tools/instruct_tuning_data/vg_relation.py
@@ -0,0 +1,238 @@
+import os
+import orjson
+import json
+import webdataset as wds
+from tqdm import tqdm, trange
+import h5py
+import numpy as np
+from utils import MAXCOUNT, NAMING, check_sample
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/instruct/vg_relation"
+BOX_SCALE = 512
+
+def load_image_filenames(image_file, image_dir):
+    """
+    Loads the image filenames from visual genome from the JSON file that contains them.
+    This matches the preprocessing in scene-graph-TF-release/data_tools/vg_to_imdb.py.
+    :param image_file: JSON file. Elements contain the param "image_id".
+    :param image_dir: directory where the VisualGenome images are located
+    :return: List of filenames corresponding to the good images
+    """
+    with open(image_file, 'r') as f:
+        im_data = json.load(f)
+
+    corrupted_ims = ['1592.jpg', '1722.jpg', '4616.jpg', '4617.jpg']
+    fns = []
+    for i, img in enumerate(tqdm(im_data)):
+        basename = '{}.jpg'.format(img['image_id'])
+        height = int(img['height'])
+        width = int(img['width'])
+        if basename in corrupted_ims:
+            continue
+
+        filename = os.path.join(image_dir, basename)
+        if os.path.exists(filename):
+            fns.append([filename, height, width])
+    assert len(fns) == 108073
+    return fns
+
+
+def load_graphs(graphs_file, mode='train', num_im=-1, num_val_im=0, filter_empty_rels=True,
+                filter_non_overlap=False):
+    """
+    Load the file containing the GT boxes and relations, as well as the dataset split
+    :param graphs_file: HDF5
+    :param mode: (train, val, or test)
+    :param num_im: Number of images we want
+    :param num_val_im: Number of validation images
+    :param filter_empty_rels: (will be filtered otherwise.)
+    :param filter_non_overlap: If training, filter images that dont overlap.
+    :return: image_index: numpy array corresponding to the index of images we're using
+             boxes: List where each element is a [num_gt, 4] array of ground 
+                    truth boxes (x1, y1, x2, y2)
+             gt_classes: List where each element is a [num_gt] array of classes
+             relationships: List where each element is a [num_r, 3] array of 
+                    (box_ind_1, box_ind_2, predicate) relationships
+    """
+    if mode not in ('train', 'val', 'test'):
+        raise ValueError('{} invalid'.format(mode))
+
+    roi_h5 = h5py.File(graphs_file, 'r')
+    data_split = roi_h5['split'][:]
+    split = 2 if mode == 'test' else 0
+    split_mask = data_split == split
+
+    # Filter out images without bounding boxes
+    split_mask &= roi_h5['img_to_first_box'][:] >= 0
+    if filter_empty_rels:
+        split_mask &= roi_h5['img_to_first_rel'][:] >= 0
+
+    image_index = np.where(split_mask)[0]
+    if num_im > -1:
+        image_index = image_index[:num_im]
+    if num_val_im > 0:
+        if mode == 'val':
+            image_index = image_index[:num_val_im]
+        elif mode == 'train':
+            image_index = image_index[num_val_im:]
+
+
+    split_mask = np.zeros_like(data_split).astype(bool)
+    split_mask[image_index] = True
+
+    # Get box information
+    all_labels = roi_h5['labels'][:, 0]
+    all_boxes = roi_h5['boxes_{}'.format(BOX_SCALE)][:]  # will index later
+    assert np.all(all_boxes[:, :2] >= 0)  # sanity check
+    assert np.all(all_boxes[:, 2:] > 0)  # no empty box
+
+    # convert from xc, yc, w, h to x1, y1, x2, y2
+    all_boxes[:, :2] = all_boxes[:, :2] - all_boxes[:, 2:] / 2
+    all_boxes[:, 2:] = all_boxes[:, :2] + all_boxes[:, 2:]
+
+    im_to_first_box = roi_h5['img_to_first_box'][:][split_mask]
+    im_to_last_box = roi_h5['img_to_last_box'][:][split_mask]
+    im_to_first_rel = roi_h5['img_to_first_rel'][:][split_mask]
+    im_to_last_rel = roi_h5['img_to_last_rel'][:][split_mask]
+
+    # load relation labels
+    _relations = roi_h5['relationships'][:]
+    _relation_predicates = roi_h5['predicates'][:, 0]
+    assert (im_to_first_rel.shape[0] == im_to_last_rel.shape[0])
+    assert (_relations.shape[0] == _relation_predicates.shape[0])  # sanity check
+
+    # Get everything by image.
+    boxes = []
+    gt_classes = []
+    relationships = []
+    for i in trange(len(image_index)):
+        boxes_i = all_boxes[im_to_first_box[i]:im_to_last_box[i] + 1, :]
+        gt_classes_i = all_labels[im_to_first_box[i]:im_to_last_box[i] + 1]
+
+        if im_to_first_rel[i] >= 0:
+            predicates = _relation_predicates[im_to_first_rel[i]:im_to_last_rel[i] + 1]
+            obj_idx = _relations[im_to_first_rel[i]:im_to_last_rel[i] + 1] - im_to_first_box[i]
+            assert np.all(obj_idx >= 0)
+            assert np.all(obj_idx < boxes_i.shape[0])
+            rels = np.column_stack((obj_idx, predicates))
+        else:
+            assert not filter_empty_rels
+            rels = np.zeros((0, 3), dtype=np.int32)
+
+        if filter_non_overlap:
+            raise NotImplementedError
+            assert mode == 'train'
+            inters = bbox_overlaps(boxes_i, boxes_i)
+            rel_overs = inters[rels[:, 0], rels[:, 1]]
+            inc = np.where(rel_overs > 0.0)[0]
+
+            if inc.size > 0:
+                rels = rels[inc]
+            else:
+                split_mask[image_index[i]] = 0
+                continue
+
+        boxes.append(boxes_i)
+        gt_classes.append(gt_classes_i)
+        relationships.append(rels)
+
+    return split_mask, boxes, gt_classes, relationships
+
+
+def load_info(info_file):
+    """
+    Loads the file containing the visual genome label meanings
+    :param info_file: JSON
+    :return: ind_to_classes: sorted list of classes
+             ind_to_predicates: sorted list of predicates
+    """
+    info = json.load(open(info_file, 'r'))
+    info['label_to_idx']['__background__'] = 0
+    info['predicate_to_idx']['__background__'] = 0
+
+    class_to_ind = info['label_to_idx']
+    predicate_to_ind = info['predicate_to_idx']
+    ind_to_classes = sorted(class_to_ind, key=lambda k: class_to_ind[k])
+    ind_to_predicates = sorted(predicate_to_ind, key=lambda k: predicate_to_ind[k])
+
+    return ind_to_classes, ind_to_predicates
+
+
+if __name__ == "__main__":
+    root = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg"
+    filenames = load_image_filenames(os.path.join(root, "image_data.json"), os.path.join(root, "VG_100K"))
+    split_mask, boxes, gt_classes, relationships = load_graphs(
+        graphs_file=os.path.join(root, "VG-SGG.h5"),
+        mode="train",
+    )
+    split_filenames = []
+    for i, mask in enumerate(split_mask):
+        if mask:
+            split_filenames.append(filenames[i])
+    filenames = split_filenames
+    ind_to_classes, ind_to_predicates = load_info(os.path.join(root, "VG-SGG-dicts.json"))
+    assert len(filenames) == len(boxes)
+    assert len(filenames) == len(gt_classes)
+    assert len(filenames) == len(relationships)
+    uuid = 0
+    os.makedirs(OUT_DIR, exist_ok=True)
+    pbar = tqdm()
+    with wds.ShardWriter(os.path.join(OUT_DIR, NAMING), maxcount=MAXCOUNT) as sink:
+        for box, box_class, relationship, (filename, height, width) in zip(boxes, gt_classes, relationships, filenames):
+            size = float(BOX_SCALE) / max(height, width)
+            size = np.array([width, height, width, height]) * size
+            box = (box.astype(float) / size).clip(0, 1)
+            for relation in relationship:
+                box1_id = relation[0]
+                box2_id = relation[1]
+                predicate = ind_to_predicates[relation[2]]
+                box1 = [box[box1_id], ind_to_classes[box_class[box1_id]]]
+                box2 = [box[box2_id], ind_to_classes[box_class[box2_id]]]
+                data = [box1, box2, predicate]
+                dataset = "vg_relation"
+                image_path = filename
+                key = f"{dataset}_{uuid}"
+                uuid += 1
+                pbar.update()
+                sample = {
+                    "__key__": key,
+                    "image_path.txt": image_path,
+                    "dataset.txt": dataset,
+                    "data.pyd": data,
+                }
+                check_sample(sample)
+                sink.write(sample)
+
+
+# if __name__ == "__main__":
+#     root = "/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg"
+#     relationships = orjson.loads(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/relationships.json").read())
+#     image_data = orjson.loads(open("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg/image_data.json").read())
+#     image_id_to_filename = {}
+#     image_id_to_wh = {}
+#     for image in tqdm(image_data):
+#         image_id = image["image_id"]
+#         subfolder, filename = image['url'].split("/")[-2:]
+#         image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
+#         image_id_to_wh[image_id] = (image["width"], image["height"])
+#     unique_predicates = []
+#     # with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=500) as sink:
+#     for relation_per_image in tqdm(relationships):
+#         image_id = relation_per_image["image_id"]
+#         for relation in relation_per_image["relationships"]:
+#             predicate = relation["predicate"]
+#             unique_predicates.append(predicate)
+#             object = {
+#                 "name": relation["object"]["name"],
+#                 "x": relation["object"]["x"],
+#                 "y": relation["object"]["y"],
+#                 "w": relation["object"]["w"],
+#                 "h": relation["object"]["h"],
+#             }
+#             subject = {
+#                 "name": relation["subject"]["name"],
+#                 "x": relation["subject"]["x"],
+#                 "y": relation["subject"]["y"],
+#                 "w": relation["subject"]["w"],
+#                 "h": relation["subject"]["h"],
+#             }
+
diff --git a/multimodal/tools/make_gqa_val.py b/multimodal/tools/make_gqa_val.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/multimodal/tools/make_mmc4_global_table.py b/multimodal/tools/make_mmc4_global_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..9701f099d7e0d7c7653bac18eefce0068c40088c
--- /dev/null
+++ b/multimodal/tools/make_mmc4_global_table.py
@@ -0,0 +1,31 @@
+import webdataset as wds
+import glob
+import os
+from tqdm import tqdm
+from tqdm.contrib.concurrent import process_map
+import pickle as pkl
+
+
+def single_thread(filename):
+    id_table = {}
+    dataset = wds.WebDataset(filename).decode().to_tuple("json")
+    for data in dataset:
+        data = data[0]
+        image_id = data["caption"].split(".")[0]
+        image_key = data["key"]
+        tarfile = os.path.basename(filename)
+        if image_id not in id_table:
+            id_table[image_id] = [tarfile, image_key]
+    return id_table
+
+if __name__ == "__main__":
+    filenames = sorted(glob.glob("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/mmc4/images/*.tar"))[:16000]
+    print("start from", filenames[0])
+    print("to", filenames[-1])
+    id_tables = process_map(single_thread, filenames, max_workers=64)
+    id_table = {}
+    for table in tqdm(id_tables):
+        id_table.update(table)
+    print("total unique image:", len(id_table))
+    pkl.dump(id_table, open("mmc4_id_table.pkl", "wb"))
+    print("DONE")
diff --git a/multimodal/tools/make_soft_link.py b/multimodal/tools/make_soft_link.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3aa8b909b5d132700190701a609a1c1ee7bb79d
--- /dev/null
+++ b/multimodal/tools/make_soft_link.py
@@ -0,0 +1,26 @@
+import os
+import shutil
+import glob
+import random
+
+DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw"
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_mini_dataset_full_karpathy"
+
+
+if __name__ == "__main__":
+    os.makedirs(OUT_DIR, exist_ok=True)
+    cc3m_tars = glob.glob(os.path.join(DIR, "cc3m", "cc3m_*", "*.tar"))
+    cc12m_tars = glob.glob(os.path.join(DIR, "cc12m", "tars", "*.tar"))
+    coco_tars = glob.glob(os.path.join(DIR, "karpathy_coco_wds_full", "*.tar"))
+    vg_tars = glob.glob(os.path.join(DIR, "vg_wds_full", "*.tar"))
+    tars = []
+    tars.extend(cc3m_tars)
+    tars.extend(cc12m_tars)
+    tars.extend(coco_tars)
+    tars.extend(vg_tars)
+    random.shuffle(tars)
+    for i, tar in enumerate(tars):
+        dst = os.path.join(OUT_DIR, f"{str(i).zfill(6)}.tar")
+        print(tar, dst)
+        os.symlink(tar, dst)
+
diff --git a/multimodal/tools/make_soft_link_blip2_data.py b/multimodal/tools/make_soft_link_blip2_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..67c108652301f951a228a41281814485792c0b43
--- /dev/null
+++ b/multimodal/tools/make_soft_link_blip2_data.py
@@ -0,0 +1,30 @@
+import os
+import shutil
+import glob
+import random
+from pprint import pprint
+
+DIR_COCO_VG = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw"
+DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining"
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_all_data_ground"
+
+
+if __name__ == "__main__":
+    os.makedirs(OUT_DIR, exist_ok=True)
+    ccs_tars = glob.glob(os.path.join(DIR, "ccs_synthetic_filtered_large_ground", "*.tar"))
+    coco_tars = glob.glob(os.path.join(DIR_COCO_VG, "karpathy_coco_wds_full_ground", "*.tar"))
+    vg_tars = glob.glob(os.path.join(DIR_COCO_VG, "vg_wds_full_ground", "*.tar"))
+    laion_part_tars = glob.glob(os.path.join(DIR, "laion_synthetic_filtered_large", "all_ground", "*.tar"))
+    tars = []
+    tars.extend(ccs_tars)
+    for _ in range(5):
+        tars.extend(coco_tars)
+    tars.extend(vg_tars)
+    tars.extend(laion_part_tars)
+    random.shuffle(tars)
+    print(len(tars))
+    pprint(tars[:20])
+    for i, tar in enumerate(tars):
+        dst = os.path.join(OUT_DIR, f"{str(i).zfill(6)}.tar")
+        # print(tar, dst)
+        os.symlink(tar, dst)
diff --git a/multimodal/tools/make_soft_link_laion.py b/multimodal/tools/make_soft_link_laion.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd7a872153fa4082f6756d902ae3198a43fad79f
--- /dev/null
+++ b/multimodal/tools/make_soft_link_laion.py
@@ -0,0 +1,23 @@
+import os
+import shutil
+import glob
+import random
+from pprint import pprint
+
+DIR_COCO_VG = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw"
+DIR =     "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/"
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/blip2_pretraining/laion_synthetic_filtered_large/all"
+
+
+if __name__ == "__main__":
+    os.makedirs(OUT_DIR, exist_ok=True)
+    tars = []
+    for i in range(10):
+        laion_part_tars = glob.glob(os.path.join(DIR, "laion_synthetic_filtered_large", f"part{i}", "*.tar"))
+        tars.extend(laion_part_tars)
+    print(len(tars))
+    pprint(tars[:20])
+    for i, tar in enumerate(tars):
+        dst = os.path.join(OUT_DIR, f"{str(i).zfill(6)}.tar")
+        # print(tar, dst)
+        os.symlink(tar, dst)
diff --git a/multimodal/tools/make_vqav2_ft_dataset.py b/multimodal/tools/make_vqav2_ft_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ee095b398727a48b2aad2d59c7f8af232b606cb
--- /dev/null
+++ b/multimodal/tools/make_vqav2_ft_dataset.py
@@ -0,0 +1,24 @@
+import webdataset as wds
+import os
+from tqdm import tqdm
+from PIL import Image
+from io import BytesIO
+import base64
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vqav2_train_wds"
+TOTAL = 1828467
+
+if __name__ == "__main__":
+    with wds.ShardWriter(os.path.join(OUT_DIR, "%06d.tar"), maxcount=10000) as sink:
+        sink.verbose = False
+        f = open("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vqav2_ofa/vqa_train.tsv")
+        for data in tqdm(f, total=TOTAL):
+            data = data.rstrip().split("\t")
+            id1 = data[0]
+            id2 = data[1]
+            question = data[2]
+            answer = data[3].split("|!+")[-1]
+            image = data[5]
+            id3 = data[6]
+            image = Image.open(BytesIO(base64.urlsafe_b64decode(image))).convert("RGB")
+            caption = f"Question: {question.strip()} Answer: {answer.strip()}"
+            sink.write({"__key__": f"vqav2_{id1}_{id2}_{id3}", "jpg": image, "txt": caption})
\ No newline at end of file
diff --git a/multimodal/tools/prepare_mini_blip2_dataset.py b/multimodal/tools/prepare_mini_blip2_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ffaee6c64ac04d650673503d75e405503ffbcd5
--- /dev/null
+++ b/multimodal/tools/prepare_mini_blip2_dataset.py
@@ -0,0 +1,178 @@
+import webdataset as wds
+import glob
+import os
+from tqdm import tqdm
+import orjson as json
+import itertools
+from PIL import Image
+import numpy as np
+from typing import List
+
+class Generator():
+    def __init__(self, dataset_name):
+        self.dataset_name = dataset_name
+        self.is_end = False
+
+class CC3MGenerator(Generator):
+    def __init__(self, root: str, dataset_name="cc3m"):
+        super().__init__(dataset_name=dataset_name)
+        self.tars = glob.glob(os.path.join(root, "cc3m_*", "*.tar"))
+
+    def __len__(self):
+        return 3000000
+
+    def __iter__(self):
+        for tar in self.tars:
+            dataset = wds.WebDataset(tar).decode("pilrgb").to_tuple("jpg;png;jpeg", "txt")
+            for data in dataset:
+                yield [self.dataset_name] + list(data)
+        self.is_end = True
+
+class CC12MGenerator(CC3MGenerator):
+    def __init__(self, root: str):
+        super().__init__(root, "cc12m")
+        self.tars = glob.glob(os.path.join(root, "*.tar"))
+
+    def __len__(self):
+        return 12000000
+
+class COCOGenerator(Generator):
+    def __init__(self, anno: str, image_dir):
+        super().__init__(dataset_name="coco")
+        data = json.loads(open(anno).read())
+        self.annotations = data["annotations"]
+        self.image_id_to_filename = {}
+        for image in data["images"]:
+            file_name = image["file_name"]
+            image_id = image["id"]
+            self.image_id_to_filename[image_id] = os.path.join(image_dir, file_name)
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __iter__(self):
+        for anno in self.annotations:
+            image_id = anno["image_id"]
+            caption = anno["caption"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                continue
+            yield [self.dataset_name, image, caption]
+        self.is_end = True
+
+
+class KarpathyCOCOGenerator(Generator):
+    def __init__(self, anno="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/tools/coco_karpathy_train.json", image_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco/images"):
+        super().__init__(dataset_name="coco")
+        data = json.loads(open(anno).read())
+        self.annotations = data
+        self.image_id_to_filename = {}
+        for d in data:
+            self.image_id_to_filename[d["image_id"]] = os.path.join(image_dir, d["image"])
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __iter__(self):
+        for anno in self.annotations:
+            image_id = anno["image_id"]
+            caption = anno["caption"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                print(self.image_id_to_filename[image_id])
+            yield [self.dataset_name, image, caption]
+        self.is_end = True
+
+
+class VisualGenomeGenerator(Generator):
+    def __init__(self, root: str):
+        super().__init__(dataset_name="vg")
+        data = json.loads(open(os.path.join(root, "region_descriptions.json")).read())
+        image_data = json.loads(open(os.path.join(root, "image_data.json")).read())
+        self.image_id_to_filename = {}
+        self.image_id_to_wh = {}
+        for image in image_data:
+            image_id = image["image_id"]
+            subfolder, filename = image['url'].split("/")[-2:]
+            self.image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
+            self.image_id_to_wh[image_id] = (image["width"], image["height"])
+        self.regions = []
+        total = 0
+        total_image = 0
+        used_image = 0
+        for xx in data:
+            total_image += 1
+            flag = False
+            for region in xx['regions']:
+                total += 1
+                region_w = int(region["width"])
+                region_h = int(region["height"])
+                image_w = self.image_id_to_wh[region["image_id"]][0]
+                image_h = self.image_id_to_wh[region["image_id"]][1]
+                if region_w * region_h < (image_w * image_h) * 0.2:
+                    continue
+                self.regions.append(region)
+                flag = True
+            if flag:
+                used_image += 1
+        print("valid region", len(self.regions), total, len(self.regions) / total)
+        print("valid image", used_image, total_image, used_image / total_image)
+
+    def __len__(self):
+        return len(self.regions)
+
+    def __iter__(self):
+        for region in self.regions:
+            image_id = region["image_id"]
+            phrase = region["phrase"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                continue
+            yield [self.dataset_name, image, phrase]
+        self.is_end = True
+
+class ShuffleGenerator():
+    def __init__(self, generators: List[Generator], p: List[int]):
+        self.generators = generators
+        self.p = list(np.array(p) / sum(p))
+        self.ids = list(range(len(self.generators)))
+        print("rebalance", self.ids, self.p)
+
+    def __len__(self):
+        return sum([len(g) for g in self.generators])
+
+    def __iter__(self):
+        while True:
+            if len(self.ids) == 0:
+                break
+            id = np.random.choice(self.ids, p=self.p)
+            gen = self.generators[id]
+            if gen.is_end:
+                print(gen.dataset_name, "is all done")
+                del self.ids[id]
+                del self.p[id]
+                self.p = list(np.array(self.p) / sum(p))
+                print("rebalance", self.ids, self.p)
+            else:
+                return iter(gen)
+
+
+if __name__ == "__main__":
+    OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_withBox_wds"
+    os.makedirs(OUT_DIR, exist_ok=True)
+    # cc3m_generator = CC3MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc3m")
+    # cc12m_generator = CC12MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc12m/tars")
+    # coco_generator = KarpathyCOCOGenerator()
+    visual_genome_generator = VisualGenomeGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg")
+    # generators = [cc3m_generator, cc12m_generator, coco_generator, visual_genome_generator]
+    # p = [len(generator) for generator in generators]
+    # dataset = ShuffleGenerator(generators, p)
+
+    with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=8500) as sink:
+        sink.verbose = False
+        for i, data in enumerate(tqdm(visual_genome_generator)):
+            dataset_name, image, caption = data
+            sink.write({"__key__": f"{dataset_name}_{i}_containBox", "jpg": image, "txt": caption})
diff --git a/multimodal/tools/prepare_pile.py b/multimodal/tools/prepare_pile.py
new file mode 100644
index 0000000000000000000000000000000000000000..e35fba8e1cecb33f551e57b190756b40139bafee
--- /dev/null
+++ b/multimodal/tools/prepare_pile.py
@@ -0,0 +1,31 @@
+import datasets
+import os
+from tqdm import tqdm
+import webdataset as wds
+import json
+
+DATASET_ROOT = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/the_pile/all/train"
+OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/the_pile"
+SAMPLE_PER_SHARD = 100000
+
+if __name__ == "__main__":
+    os.makedirs(OUT_DIR)
+    print("load dataset...")
+    pile = datasets.load_from_disk(DATASET_ROOT)
+    total_num = pile.num_rows
+    print("total num:", total_num)
+    num = 0
+    pbar = tqdm(total=total_num)
+    with wds.ShardWriter(OUT_DIR+"/%05d.tar", maxcount=SAMPLE_PER_SHARD, encoder=False) as sink:
+        for sample in pile.iter(4096):
+            for text, meta in zip(sample["text"], sample["meta"]):
+                pbar.update(1)
+                if meta.get("pile_set_name", None) == "Github":
+                    continue
+                num += 1
+                sink.write({
+                    '__key__': str(num),
+                    'txt': text.encode("utf-8"),
+                    'json': json.dumps(meta, indent=4).encode("utf-8"),
+                })
+    print(f"{num} out of {total_num} is written")
diff --git a/multimodal/tools/prepare_vg_regional_box.py b/multimodal/tools/prepare_vg_regional_box.py
new file mode 100644
index 0000000000000000000000000000000000000000..c67dfc3554703a011955018b3c9b3d42b35ca15a
--- /dev/null
+++ b/multimodal/tools/prepare_vg_regional_box.py
@@ -0,0 +1,120 @@
+import webdataset as wds
+import glob
+import os
+from tqdm import tqdm
+import orjson as json
+import itertools
+from PIL import Image
+import numpy as np
+from typing import List
+import cv2
+import random
+from tqdm.contrib.concurrent import process_map
+from copy import deepcopy
+
+class Generator():
+    def __init__(self, dataset_name):
+        self.dataset_name = dataset_name
+        self.is_end = False
+
+
+class VisualGenomeGenerator(Generator):
+    def __init__(self, root: str):
+        super().__init__(dataset_name="vg")
+        data = json.loads(open(os.path.join(root, "region_descriptions.json")).read())
+        image_data = json.loads(open(os.path.join(root, "image_data.json")).read())
+        self.image_id_to_filename = {}
+        self.image_id_to_wh = {}
+        for image in image_data:
+            image_id = image["image_id"]
+            subfolder, filename = image['url'].split("/")[-2:]
+            self.image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
+            self.image_id_to_wh[image_id] = (image["width"], image["height"])
+        self.regions = []
+        total = 0
+        total_image = 0
+        used_image = 0
+        for xx in tqdm(data):
+            total_image += 1
+            flag = False
+            for region in xx['regions']:
+                total += 1
+                region_w = int(region["width"])
+                region_h = int(region["height"])
+                x = int(region["x"])
+                y = int(region["y"])
+                image_w = self.image_id_to_wh[region["image_id"]][0]
+                image_h = self.image_id_to_wh[region["image_id"]][1]
+                region_w /= image_w
+                region_h /= image_h
+                x /= image_w
+                y /= image_h
+                if region_w * region_h < 1 / (16*16*4):
+                    continue
+                if " is" in region["phrase"] or " are" in region["phrase"]:
+                    continue
+                region["norm_xywh"] = (x, y, region_w, region_h)
+                self.regions.append(region)
+                flag = True
+            if flag:
+                used_image += 1
+        random.shuffle(self.regions)
+        print("valid region", len(self.regions), total, len(self.regions) / total)
+        print("valid image", used_image, total_image, used_image / total_image)
+
+    def __len__(self):
+        return len(self.regions)
+
+    def __iter__(self):
+        for region in self.regions:
+            image_id = region["image_id"]
+            phrase = region["phrase"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                continue
+            image = image.resize((224, 224))
+            x, y, region_w, region_h = region["norm_xywh"]
+            x1 = int(x * 224)
+            y1 = int(y * 224)
+            x2 = int(x1 + region_w * 224)
+            y2 = int(y1 + region_h * 224)
+            yield [self.dataset_name, image, phrase, np.array([x1, y1, x2, y2]), image_id]
+        self.is_end = True
+
+
+def handle(args):
+    dataset_name = "vg"
+    iii, regions, image_id_to_filename = args
+    if iii == 0:
+        print(regions[:10])
+    os.makedirs(os.path.join(OUT_DIR, str(iii)), exist_ok=True)
+    with wds.ShardWriter(os.path.join(OUT_DIR, str(iii), "%06d.tar"), maxcount=8500) as sink:
+        sink.verbose = False
+        for i, region in enumerate(tqdm(regions, disable=(iii != 0))):
+            image_id = region["image_id"]
+            phrase = region["phrase"]
+            image = Image.open(image_id_to_filename[image_id])
+            image = image.resize((224, 224))
+            x, y, region_w, region_h = region["norm_xywh"]
+            x1 = int(x * 224)
+            y1 = int(y * 224)
+            x2 = int(x1 + region_w * 224)
+            y2 = int(y1 + region_h * 224)
+            dataset_name, image, caption, xyxy, image_id = [dataset_name, image, phrase, np.array([x1, y1, x2, y2]), image_id]
+            sink.write({"__key__": f"{dataset_name}_{i}_containBox", "jpg": image, "txt": caption, "boxes.pyd": xyxy, "logits.pyd": xyxy})
+            if i % 200 == 0 and iii == 0:
+                tqdm.write(f"{caption} {xyxy}")
+
+
+if __name__ == "__main__":
+    OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_0826"
+    os.makedirs(OUT_DIR, exist_ok=True)
+    visual_genome_generator = VisualGenomeGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg")
+    N_PROC = 150
+    data_list = []
+    for i in range(N_PROC):
+        data_list.append([i, [], deepcopy(visual_genome_generator.image_id_to_filename)])
+    for i, region in enumerate(visual_genome_generator.regions):
+        data_list[i % N_PROC][1].append(region)
+    process_map(handle, data_list, max_workers=N_PROC, disable=True)
diff --git a/multimodal/tools/prepare_vg_with_box.py b/multimodal/tools/prepare_vg_with_box.py
new file mode 100644
index 0000000000000000000000000000000000000000..e626a6072a324351a741d0a5827961bd57355822
--- /dev/null
+++ b/multimodal/tools/prepare_vg_with_box.py
@@ -0,0 +1,205 @@
+import webdataset as wds
+import glob
+import os
+from tqdm import tqdm
+import orjson as json
+import itertools
+from PIL import Image
+import numpy as np
+from typing import List
+import cv2
+import random
+
+class Generator():
+    def __init__(self, dataset_name):
+        self.dataset_name = dataset_name
+        self.is_end = False
+
+class CC3MGenerator(Generator):
+    def __init__(self, root: str, dataset_name="cc3m"):
+        super().__init__(dataset_name=dataset_name)
+        self.tars = glob.glob(os.path.join(root, "cc3m_*", "*.tar"))
+
+    def __len__(self):
+        return 3000000
+
+    def __iter__(self):
+        for tar in self.tars:
+            dataset = wds.WebDataset(tar).decode("pilrgb").to_tuple("jpg;png;jpeg", "txt")
+            for data in dataset:
+                yield [self.dataset_name] + list(data)
+        self.is_end = True
+
+class CC12MGenerator(CC3MGenerator):
+    def __init__(self, root: str):
+        super().__init__(root, "cc12m")
+        self.tars = glob.glob(os.path.join(root, "*.tar"))
+
+    def __len__(self):
+        return 12000000
+
+class COCOGenerator(Generator):
+    def __init__(self, anno: str, image_dir):
+        super().__init__(dataset_name="coco")
+        data = json.loads(open(anno).read())
+        self.annotations = data["annotations"]
+        self.image_id_to_filename = {}
+        for image in data["images"]:
+            file_name = image["file_name"]
+            image_id = image["id"]
+            self.image_id_to_filename[image_id] = os.path.join(image_dir, file_name)
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __iter__(self):
+        for anno in self.annotations:
+            image_id = anno["image_id"]
+            caption = anno["caption"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                continue
+            yield [self.dataset_name, image, caption]
+        self.is_end = True
+
+
+class KarpathyCOCOGenerator(Generator):
+    def __init__(self, anno="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/tools/coco_karpathy_train.json", image_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco/images"):
+        super().__init__(dataset_name="coco")
+        data = json.loads(open(anno).read())
+        self.annotations = data
+        self.image_id_to_filename = {}
+        for d in data:
+            self.image_id_to_filename[d["image_id"]] = os.path.join(image_dir, d["image"])
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __iter__(self):
+        for anno in self.annotations:
+            image_id = anno["image_id"]
+            caption = anno["caption"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                print(self.image_id_to_filename[image_id])
+            yield [self.dataset_name, image, caption]
+        self.is_end = True
+
+
+class VisualGenomeGenerator(Generator):
+    def __init__(self, root: str):
+        super().__init__(dataset_name="vg")
+        data = json.loads(open(os.path.join(root, "region_descriptions.json")).read())
+        image_data = json.loads(open(os.path.join(root, "image_data.json")).read())
+        self.image_id_to_filename = {}
+        self.image_id_to_wh = {}
+        for image in image_data:
+            image_id = image["image_id"]
+            subfolder, filename = image['url'].split("/")[-2:]
+            self.image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
+            self.image_id_to_wh[image_id] = (image["width"], image["height"])
+        self.regions = []
+        total = 0
+        total_image = 0
+        used_image = 0
+        for xx in data:
+            total_image += 1
+            flag = False
+            for region in xx['regions']:
+                total += 1
+                region_w = int(region["width"])
+                region_h = int(region["height"])
+                x = int(region["x"])
+                y = int(region["y"])
+                image_w = self.image_id_to_wh[region["image_id"]][0]
+                image_h = self.image_id_to_wh[region["image_id"]][1]
+                region_w /= image_w
+                region_h /= image_h
+                x /= image_w
+                y /= image_h
+                if region_w * region_h < 0.1:
+                    continue
+                if " is" in region["phrase"] or " are" in region["phrase"] or len(region["phrase"].split(" ")) <= 7:
+                    continue
+                region["norm_xywh"] = (x, y, region_w, region_h)
+                self.regions.append(region)
+                flag = True
+            if flag:
+                used_image += 1
+        random.shuffle(self.regions)
+        print("valid region", len(self.regions), total, len(self.regions) / total)
+        print("valid image", used_image, total_image, used_image / total_image)
+
+    def __len__(self):
+        return len(self.regions)
+
+    def __iter__(self):
+        for region in self.regions:
+            image_id = region["image_id"]
+            phrase = region["phrase"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                continue
+            image = image.resize((224, 224))
+            x, y, region_w, region_h = region["norm_xywh"]
+            x1 = int(x * 224)
+            y1 = int(y * 224)
+            x2 = int(x1 + region_w * 224)
+            y2 = int(y1 + region_h * 224)
+            # open_cv_image = np.array(image)
+            # # Convert RGB to BGR
+            # open_cv_image = open_cv_image[:, :, ::-1].copy()
+            # open_cv_image = cv2.rectangle(open_cv_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
+            # cv2.imwrite("vg.jpg", open_cv_image)
+            # import pdb; pdb.set_trace()
+            yield [self.dataset_name, image, phrase, np.array([x1, y1, x2, y2]), image_id]
+        self.is_end = True
+
+class ShuffleGenerator():
+    def __init__(self, generators: List[Generator], p: List[int]):
+        self.generators = generators
+        self.p = list(np.array(p) / sum(p))
+        self.ids = list(range(len(self.generators)))
+        print("rebalance", self.ids, self.p)
+
+    def __len__(self):
+        return sum([len(g) for g in self.generators])
+
+    def __iter__(self):
+        while True:
+            if len(self.ids) == 0:
+                break
+            id = np.random.choice(self.ids, p=self.p)
+            gen = self.generators[id]
+            if gen.is_end:
+                print(gen.dataset_name, "is all done")
+                del self.ids[id]
+                del self.p[id]
+                self.p = list(np.array(self.p) / sum(p))
+                print("rebalance", self.ids, self.p)
+            else:
+                return iter(gen)
+
+
+if __name__ == "__main__":
+    OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_withBox_L7_wds"
+    os.makedirs(OUT_DIR, exist_ok=True)
+    # cc3m_generator = CC3MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc3m")
+    # cc12m_generator = CC12MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc12m/tars")
+    # coco_generator = KarpathyCOCOGenerator()
+    visual_genome_generator = VisualGenomeGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg")
+    # generators = [cc3m_generator, cc12m_generator, coco_generator, visual_genome_generator]
+    # p = [len(generator) for generator in generators]
+    # dataset = ShuffleGenerator(generators, p)
+
+    with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=8500) as sink:
+        sink.verbose = False
+        pbar = tqdm(visual_genome_generator)
+        for i, data in enumerate(pbar):
+            dataset_name, image, caption, xyxy, image_id = data
+            sink.write({"__key__": f"{dataset_name}_{i}_containBox", "jpg": image, "txt": caption, "xyxy.pyd": xyxy})
+            if i % 200 == 0:
+                tqdm.write(f"{caption} {xyxy}")
diff --git a/multimodal/tools/prepare_vg_with_box2.py b/multimodal/tools/prepare_vg_with_box2.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a96c40b32944d0f9e8da4c2205a446f6fc6d92f
--- /dev/null
+++ b/multimodal/tools/prepare_vg_with_box2.py
@@ -0,0 +1,205 @@
+import webdataset as wds
+import glob
+import os
+from tqdm import tqdm
+import orjson as json
+import itertools
+from PIL import Image
+import numpy as np
+from typing import List
+import cv2
+import random
+
+class Generator():
+    def __init__(self, dataset_name):
+        self.dataset_name = dataset_name
+        self.is_end = False
+
+class CC3MGenerator(Generator):
+    def __init__(self, root: str, dataset_name="cc3m"):
+        super().__init__(dataset_name=dataset_name)
+        self.tars = glob.glob(os.path.join(root, "cc3m_*", "*.tar"))
+
+    def __len__(self):
+        return 3000000
+
+    def __iter__(self):
+        for tar in self.tars:
+            dataset = wds.WebDataset(tar).decode("pilrgb").to_tuple("jpg;png;jpeg", "txt")
+            for data in dataset:
+                yield [self.dataset_name] + list(data)
+        self.is_end = True
+
+class CC12MGenerator(CC3MGenerator):
+    def __init__(self, root: str):
+        super().__init__(root, "cc12m")
+        self.tars = glob.glob(os.path.join(root, "*.tar"))
+
+    def __len__(self):
+        return 12000000
+
+class COCOGenerator(Generator):
+    def __init__(self, anno: str, image_dir):
+        super().__init__(dataset_name="coco")
+        data = json.loads(open(anno).read())
+        self.annotations = data["annotations"]
+        self.image_id_to_filename = {}
+        for image in data["images"]:
+            file_name = image["file_name"]
+            image_id = image["id"]
+            self.image_id_to_filename[image_id] = os.path.join(image_dir, file_name)
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __iter__(self):
+        for anno in self.annotations:
+            image_id = anno["image_id"]
+            caption = anno["caption"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                continue
+            yield [self.dataset_name, image, caption]
+        self.is_end = True
+
+
+class KarpathyCOCOGenerator(Generator):
+    def __init__(self, anno="/gpfs/u/home/LMCG/LMCGljnn/scratch/code/multimodal/tools/coco_karpathy_train.json", image_dir="/gpfs/u/home/LMCG/LMCGljnn/scratch/.cache/lavis/coco/images"):
+        super().__init__(dataset_name="coco")
+        data = json.loads(open(anno).read())
+        self.annotations = data
+        self.image_id_to_filename = {}
+        for d in data:
+            self.image_id_to_filename[d["image_id"]] = os.path.join(image_dir, d["image"])
+
+    def __len__(self):
+        return len(self.annotations)
+
+    def __iter__(self):
+        for anno in self.annotations:
+            image_id = anno["image_id"]
+            caption = anno["caption"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                print(self.image_id_to_filename[image_id])
+            yield [self.dataset_name, image, caption]
+        self.is_end = True
+
+
+class VisualGenomeGenerator(Generator):
+    def __init__(self, root: str):
+        super().__init__(dataset_name="vg")
+        data = json.loads(open(os.path.join(root, "region_descriptions.json")).read())
+        image_data = json.loads(open(os.path.join(root, "image_data.json")).read())
+        self.image_id_to_filename = {}
+        self.image_id_to_wh = {}
+        for image in image_data:
+            image_id = image["image_id"]
+            subfolder, filename = image['url'].split("/")[-2:]
+            self.image_id_to_filename[image_id] = os.path.join(root, subfolder, filename)
+            self.image_id_to_wh[image_id] = (image["width"], image["height"])
+        self.regions = []
+        total = 0
+        total_image = 0
+        used_image = 0
+        for xx in data:
+            total_image += 1
+            flag = False
+            for region in xx['regions']:
+                total += 1
+                region_w = int(region["width"])
+                region_h = int(region["height"])
+                x = int(region["x"])
+                y = int(region["y"])
+                image_w = self.image_id_to_wh[region["image_id"]][0]
+                image_h = self.image_id_to_wh[region["image_id"]][1]
+                region_w /= image_w
+                region_h /= image_h
+                x /= image_w
+                y /= image_h
+                if region_w * region_h < 0.1:
+                    continue
+                if " is" in region["phrase"] or " are" in region["phrase"]:
+                    continue
+                region["norm_xywh"] = (x, y, region_w, region_h)
+                self.regions.append(region)
+                flag = True
+            if flag:
+                used_image += 1
+        random.shuffle(self.regions)
+        print("valid region", len(self.regions), total, len(self.regions) / total)
+        print("valid image", used_image, total_image, used_image / total_image)
+
+    def __len__(self):
+        return len(self.regions)
+
+    def __iter__(self):
+        for region in self.regions:
+            image_id = region["image_id"]
+            phrase = region["phrase"]
+            try:
+                image = Image.open(self.image_id_to_filename[image_id])
+            except:
+                continue
+            image = image.resize((224, 224))
+            x, y, region_w, region_h = region["norm_xywh"]
+            x1 = int(x * 224)
+            y1 = int(y * 224)
+            x2 = int(x1 + region_w * 224)
+            y2 = int(y1 + region_h * 224)
+            # open_cv_image = np.array(image)
+            # # Convert RGB to BGR
+            # open_cv_image = open_cv_image[:, :, ::-1].copy()
+            # open_cv_image = cv2.rectangle(open_cv_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
+            # cv2.imwrite("vg.jpg", open_cv_image)
+            # import pdb; pdb.set_trace()
+            yield [self.dataset_name, image, phrase, np.array([x1, y1, x2, y2]), image_id]
+        self.is_end = True
+
+class ShuffleGenerator():
+    def __init__(self, generators: List[Generator], p: List[int]):
+        self.generators = generators
+        self.p = list(np.array(p) / sum(p))
+        self.ids = list(range(len(self.generators)))
+        print("rebalance", self.ids, self.p)
+
+    def __len__(self):
+        return sum([len(g) for g in self.generators])
+
+    def __iter__(self):
+        while True:
+            if len(self.ids) == 0:
+                break
+            id = np.random.choice(self.ids, p=self.p)
+            gen = self.generators[id]
+            if gen.is_end:
+                print(gen.dataset_name, "is all done")
+                del self.ids[id]
+                del self.p[id]
+                self.p = list(np.array(self.p) / sum(p))
+                print("rebalance", self.ids, self.p)
+            else:
+                return iter(gen)
+
+
+if __name__ == "__main__":
+    OUT_DIR = "/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/vg_withBox_wds"
+    os.makedirs(OUT_DIR, exist_ok=True)
+    # cc3m_generator = CC3MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc3m")
+    # cc12m_generator = CC12MGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch-shared/junyan/raw/cc12m/tars")
+    # coco_generator = KarpathyCOCOGenerator()
+    visual_genome_generator = VisualGenomeGenerator("/gpfs/u/home/LMCG/LMCGljnn/scratch/datasets/raw/vg")
+    # generators = [cc3m_generator, cc12m_generator, coco_generator, visual_genome_generator]
+    # p = [len(generator) for generator in generators]
+    # dataset = ShuffleGenerator(generators, p)
+
+    with wds.ShardWriter(os.path.join(OUT_DIR, "%05d.tar"), maxcount=8500) as sink:
+        sink.verbose = False
+        pbar = tqdm(visual_genome_generator)
+        for i, data in enumerate(pbar):
+            dataset_name, image, caption, xyxy, image_id = data
+            sink.write({"__key__": f"{dataset_name}_{i}_containBox", "jpg": image, "txt": caption, "xyxy.pyd": xyxy})
+            if i % 200 == 0:
+                tqdm.write(f"{caption} {xyxy}")
diff --git a/multimodal/tools/vg.jpg b/multimodal/tools/vg.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..800cec4ae3cbb5d4062b8bdc200587d8fd2aa7e7
Binary files /dev/null and b/multimodal/tools/vg.jpg differ