diff --git a/.gitattributes b/.gitattributes index c7d9f3332a950355d5a77d85000f05e6f45435ea..e339d11e08bb73a20958b21166d3937c9ae479a5 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.psd filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md index 03cb7eb74f4ec004d3bf7178414b16d3631b0186..139dc454d481b5ea557249742bb32a1a9bad13ad 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ --- -title: Image Editing With GPT3 -emoji: 🐨 +title: X Decoder +emoji: 📈 colorFrom: purple -colorTo: blue +colorTo: gray sdk: gradio -sdk_version: 3.16.1 +sdk_version: 3.14.0 app_file: app.py pinned: false license: afl-3.0 diff --git a/__init__.py b/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0b85755b8d5e772a0f27efa063e180d4d83e47 --- /dev/null +++ b/app.py @@ -0,0 +1,98 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Jianwei Yang (jianwyan@microsoft.com), Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import os +os.system("python -m pip install git+https://github.com/MaureenZOU/detectron2-xyz.git") + +import gradio as gr +import torch +import argparse + +from xdecoder.BaseModel import BaseModel +from xdecoder import build_model +from utils.distributed import init_distributed +from utils.arguments import load_opt_from_config_files + +from tasks import * + +def parse_option(): + parser = argparse.ArgumentParser('X-Decoder All-in-One Demo', add_help=False) + parser.add_argument('--conf_files', default="configs/xdecoder/svlp_focalt_lang.yaml", metavar="FILE", help='path to config file', ) + args = parser.parse_args() + + return args + +''' +build args +''' +args = parse_option() +opt = load_opt_from_config_files(args.conf_files) +opt = init_distributed(opt) + +# META DATA +pretrained_pth_last = os.path.join("xdecoder_focalt_last.pt") +pretrained_pth_novg = os.path.join("xdecoder_focalt_last_novg.pt") + +if not os.path.exists(pretrained_pth_last): + os.system("wget {}".format("https://projects4jw.blob.core.windows.net/x-decoder/release/xdecoder_focalt_last.pt")) + +if not os.path.exists(pretrained_pth_novg): + os.system("wget {}".format("https://projects4jw.blob.core.windows.net/x-decoder/release/xdecoder_focalt_last_novg.pt")) + + +''' +build model +''' +model_last = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth_last).eval().cuda() +model_cap = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth_novg).eval().cuda() + +with torch.no_grad(): + model_last.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=True) + model_cap.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(["background", "background"], is_eval=True) + +''' +inference model +''' + +@torch.no_grad() +def inference(image, instruction, *args, **kwargs): + image = image.convert("RGB") + with torch.autocast(device_type='cuda', dtype=torch.float16): + return referring_inpainting_gpt3(model_last, image, instruction, *args, **kwargs) + +''' 
+launch app +''' +title = "X-Decoder + GPT-3 Instructional Image Editing" +description = "Project Page | Paper | Github Repo | Video
" + +article = "The Demo is Run on X-Decoder (Focal-T)." + +inputs = [gr.inputs.Image(type='pil'), gr.Textbox(label="instruction")] +gr.Interface( + fn=inference, + inputs=inputs, + outputs=[ + gr.outputs.Image( + type="pil", + label="segmentation results"), + gr.Textbox(label="text restuls"), + gr.outputs.Image( + type="pil", + label="inpainting results"), + ], + examples=[ + ["./images/apples.jpg", "change green apple to a red apple"], + ["./images/girl_and_two_boys.png", "remove the boy with blue backbag"], + ["./images/dog.png", "remove the dog"], + ], + title=title, + description=description, + article=article, + allow_flagging='never', + cache_examples=True, +).launch(share=True) \ No newline at end of file diff --git a/configs/xdecoder/svlp_focalt_lang.yaml b/configs/xdecoder/svlp_focalt_lang.yaml new file mode 100755 index 0000000000000000000000000000000000000000..8010124cad660e07e8de7fae1f91166ff1ac834d --- /dev/null +++ b/configs/xdecoder/svlp_focalt_lang.yaml @@ -0,0 +1,110 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +################## +# Task settings +################## +VERBOSE: true +MODEL: + NAME: xdecoder_model + HEAD: xdecoder_head + DIM_PROJ: 512 + BACKBONE_DIM: 768 + TEXT: + ARCH: vlpencoder + NAME: transformer + TOKENIZER: clip + CONTEXT_LENGTH: 77 # 77 + WIDTH: 512 + HEADS: 8 + LAYERS: 12 # 6 + AUTOGRESSIVE: True + BACKBONE: + NAME: focal_dw + PRETRAINED: '' + LOAD_PRETRAINED: false + FOCAL: + PRETRAIN_IMG_SIZE: 224 + PATCH_SIZE: 4 + EMBED_DIM: 96 + DEPTHS: [2, 2, 6, 2] + FOCAL_LEVELS: [3, 3, 3, 3] + FOCAL_WINDOWS: [3, 3, 3, 3] + DROP_PATH_RATE: 0.3 + MLP_RATIO: 4.0 + DROP_RATE: 0.0 + PATCH_NORM: True + USE_CONV_EMBED: True + SCALING_MODULATOR: True + USE_CHECKPOINT: False + USE_POSTLN: true + USE_POSTLN_IN_MODULATION: false + USE_LAYERSCALE: True + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + OUT_INDICES: [0, 1, 2, 3] + ENCODER: + NAME: transformer_encoder_fpn + IGNORE_VALUE: 255 + NUM_CLASSES: 133 + LOSS_WEIGHT: 1.0 + CONVS_DIM: 512 + MASK_DIM: 512 + NORM: "GN" + IN_FEATURES: ["res2", "res3", "res4", "res5"] + DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"] + COMMON_STRIDE: 4 + TRANSFORMER_ENC_LAYERS: 6 + DECODER: + NAME: xdecoder + TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" + MASK: True + GROUNDING: + ENABLED: True + MAX_LEN: 5 + TEXT_WEIGHT: 2.0 + CLASS_WEIGHT: 0.5 + DETECTION: False + CAPTION: + ENABLED: True + PHRASE_PROB: 0.0 + SIM_THRES: 0.95 + CAPTIONING: + ENABLED: True + STEP: 50 + RETRIEVAL: + ENABLED: True + DIM_IMG: 768 + ENSEMBLE: True + HIDDEN_DIM: 512 + NUM_OBJECT_QUERIES: 101 + NHEADS: 8 + DROPOUT: 0.0 + DIM_FEEDFORWARD: 2048 + PRE_NORM: False + ENFORCE_INPUT_PROJ: False + SIZE_DIVISIBILITY: 32 + TRAIN_NUM_POINTS: 12544 + OVERSAMPLE_RATIO: 3.0 + IMPORTANCE_SAMPLE_RATIO: 0.75 + DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query + TOP_GROUNDING_LAYERS: 3 + TOP_CAPTION_LAYERS: 3 + TOP_CAPTIONING_LAYERS: 3 + TOP_RETRIEVAL_LAYERS: 3 + TOP_OPENIMAGE_LAYERS: 10 + TEST: + SEMANTIC_ON: True + INSTANCE_ON: True + PANOPTIC_ON: True + OVERLAP_THRESHOLD: 0.8 + OBJECT_MASK_THRESHOLD: 0.4 + SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false + DETECTIONS_PER_IMAGE: 100 + +INPUT: + PIXEL_MEAN: [123.675, 116.280, 103.530] + 
PIXEL_STD: [58.395, 57.120, 57.375] \ No newline at end of file diff --git a/images/apples.jpg b/images/apples.jpg new file mode 100644 index 0000000000000000000000000000000000000000..52c854d1e31430e1100ef0231c70016b8ba835b8 Binary files /dev/null and b/images/apples.jpg differ diff --git a/images/coco/000.jpg b/images/coco/000.jpg new file mode 100755 index 0000000000000000000000000000000000000000..af38ff573f505d7fb24aca165fe36a02f4a0561e Binary files /dev/null and b/images/coco/000.jpg differ diff --git a/images/coco/001.jpg b/images/coco/001.jpg new file mode 100755 index 0000000000000000000000000000000000000000..9b68704d20116fa24faade4c6a90250231c020fa Binary files /dev/null and b/images/coco/001.jpg differ diff --git a/images/coco/002.jpg b/images/coco/002.jpg new file mode 100755 index 0000000000000000000000000000000000000000..d4dffd00b209fabcacb664f3fa7d8b4ea781a0e9 Binary files /dev/null and b/images/coco/002.jpg differ diff --git a/images/coco/003.jpg b/images/coco/003.jpg new file mode 100755 index 0000000000000000000000000000000000000000..06d3accacfd2ac0e1b501b29e5cb9f73dd819861 Binary files /dev/null and b/images/coco/003.jpg differ diff --git a/images/coco/004.jpg b/images/coco/004.jpg new file mode 100755 index 0000000000000000000000000000000000000000..36dece6f1f94660514399899ef8f430539995927 Binary files /dev/null and b/images/coco/004.jpg differ diff --git a/images/coco/005.jpg b/images/coco/005.jpg new file mode 100755 index 0000000000000000000000000000000000000000..31c6139af798e434b8aae1143cadaf258a8149db Binary files /dev/null and b/images/coco/005.jpg differ diff --git a/images/coco/006.jpg b/images/coco/006.jpg new file mode 100755 index 0000000000000000000000000000000000000000..1f0385e5735f38f24109808c2b38e46572f8e7ff Binary files /dev/null and b/images/coco/006.jpg differ diff --git a/images/coco/007.jpg b/images/coco/007.jpg new file mode 100755 index 0000000000000000000000000000000000000000..57344ef00f67f202a1d03b9fa1b93134b977e25d Binary files /dev/null and b/images/coco/007.jpg differ diff --git a/images/coco/008.jpg b/images/coco/008.jpg new file mode 100755 index 0000000000000000000000000000000000000000..665e1652e5bc26b88ddfc3e86120309e72cc4889 Binary files /dev/null and b/images/coco/008.jpg differ diff --git a/images/coco/009.jpg b/images/coco/009.jpg new file mode 100755 index 0000000000000000000000000000000000000000..03b87a9ab6fad9546d9abeeab9b07dbd5b489d89 Binary files /dev/null and b/images/coco/009.jpg differ diff --git a/images/coco/010.jpg b/images/coco/010.jpg new file mode 100755 index 0000000000000000000000000000000000000000..7767b0481e9047839d6c4fe6ae4df24c06fa9848 Binary files /dev/null and b/images/coco/010.jpg differ diff --git a/images/coco/011.jpg b/images/coco/011.jpg new file mode 100755 index 0000000000000000000000000000000000000000..9541405d88ae703a4bc013a0dfdc973576aa852f Binary files /dev/null and b/images/coco/011.jpg differ diff --git a/images/coco/012.jpg b/images/coco/012.jpg new file mode 100755 index 0000000000000000000000000000000000000000..6e4a51585d5bb1f668dc2c3aedd058ade022f58e Binary files /dev/null and b/images/coco/012.jpg differ diff --git a/images/coco/013.jpg b/images/coco/013.jpg new file mode 100755 index 0000000000000000000000000000000000000000..0eaac632ff14f217801ea35e69740f1e0b0233bc Binary files /dev/null and b/images/coco/013.jpg differ diff --git a/images/coco/014.jpg b/images/coco/014.jpg new file mode 100755 index 0000000000000000000000000000000000000000..389684fac4ad79ea89ad8a22227e7c2816d449b5 Binary files 
/dev/null and b/images/coco/014.jpg differ diff --git a/images/coco/015.jpg b/images/coco/015.jpg new file mode 100755 index 0000000000000000000000000000000000000000..47215aff0b3d89d8d9d873764c40d91ff1029087 Binary files /dev/null and b/images/coco/015.jpg differ diff --git a/images/coco/016.jpg b/images/coco/016.jpg new file mode 100755 index 0000000000000000000000000000000000000000..7c45b46f760f81f183b6870c7d874e01709fdfb7 Binary files /dev/null and b/images/coco/016.jpg differ diff --git a/images/coco/017.jpg b/images/coco/017.jpg new file mode 100755 index 0000000000000000000000000000000000000000..5e4b77deef254b0d45ae563d0480d1e17b6e2789 Binary files /dev/null and b/images/coco/017.jpg differ diff --git a/images/coco/018.jpg b/images/coco/018.jpg new file mode 100755 index 0000000000000000000000000000000000000000..dd6dc5c23f57c71babcd26e63bd2668a858fb6ee Binary files /dev/null and b/images/coco/018.jpg differ diff --git a/images/coco/019.jpg b/images/coco/019.jpg new file mode 100755 index 0000000000000000000000000000000000000000..4e3bcc1a82881db0d727b4940267b3d4b5465a37 Binary files /dev/null and b/images/coco/019.jpg differ diff --git a/images/coco/020.jpg b/images/coco/020.jpg new file mode 100755 index 0000000000000000000000000000000000000000..e3cdc1c190987575cf811b4859adb015b67437ca Binary files /dev/null and b/images/coco/020.jpg differ diff --git a/images/coco/021.jpg b/images/coco/021.jpg new file mode 100755 index 0000000000000000000000000000000000000000..527cef4736b529bc5833f9404d3b08d138f98d43 Binary files /dev/null and b/images/coco/021.jpg differ diff --git a/images/coco/022.jpg b/images/coco/022.jpg new file mode 100755 index 0000000000000000000000000000000000000000..28de1c765dfdae21a502d5c6c3d493cd5a55856f Binary files /dev/null and b/images/coco/022.jpg differ diff --git a/images/coco/023.jpg b/images/coco/023.jpg new file mode 100755 index 0000000000000000000000000000000000000000..890f544bdc9bc55bde6ec7c7e3691dbea21e4765 Binary files /dev/null and b/images/coco/023.jpg differ diff --git a/images/coco/024.jpg b/images/coco/024.jpg new file mode 100755 index 0000000000000000000000000000000000000000..f9926d71dd23f18d4041d0425a386a36aad205a2 Binary files /dev/null and b/images/coco/024.jpg differ diff --git a/images/coco/025.jpg b/images/coco/025.jpg new file mode 100755 index 0000000000000000000000000000000000000000..63cdbd33d452ee5f2abb08ace9bf9416e8b83363 Binary files /dev/null and b/images/coco/025.jpg differ diff --git a/images/coco/026.jpg b/images/coco/026.jpg new file mode 100755 index 0000000000000000000000000000000000000000..05a3883234eae9be70fb5c8c256d7da58036c449 Binary files /dev/null and b/images/coco/026.jpg differ diff --git a/images/coco/027.jpg b/images/coco/027.jpg new file mode 100755 index 0000000000000000000000000000000000000000..c93d00d5f873b7d2bfa6ac0753a9cb35536bfd55 Binary files /dev/null and b/images/coco/027.jpg differ diff --git a/images/coco/028.jpg b/images/coco/028.jpg new file mode 100755 index 0000000000000000000000000000000000000000..f003f982f242d069331e31bb154f597049dd70dc Binary files /dev/null and b/images/coco/028.jpg differ diff --git a/images/coco/029.jpg b/images/coco/029.jpg new file mode 100755 index 0000000000000000000000000000000000000000..5e71fda27908c37102501a7adde22b3d99e8438d Binary files /dev/null and b/images/coco/029.jpg differ diff --git a/images/coco/030.jpg b/images/coco/030.jpg new file mode 100755 index 0000000000000000000000000000000000000000..e7f633f9a0bfdecc8c38909d49810f06e8e4d246 Binary files /dev/null and 
b/images/coco/030.jpg differ diff --git a/images/coco/031.jpg b/images/coco/031.jpg new file mode 100755 index 0000000000000000000000000000000000000000..f6c472549a9c03f0f243d86372dfcfb26b7bdb97 Binary files /dev/null and b/images/coco/031.jpg differ diff --git a/images/coco/032.jpg b/images/coco/032.jpg new file mode 100755 index 0000000000000000000000000000000000000000..0fd9657ed002d9591107fb7175d84c5587d90955 Binary files /dev/null and b/images/coco/032.jpg differ diff --git a/images/coco/033.jpg b/images/coco/033.jpg new file mode 100755 index 0000000000000000000000000000000000000000..cca519968adc695c5f5e3636f20405ce04ef75db Binary files /dev/null and b/images/coco/033.jpg differ diff --git a/images/coco/034.jpg b/images/coco/034.jpg new file mode 100755 index 0000000000000000000000000000000000000000..1baebb5fbc1ebb838f5e1d2a47354715bd2c718d Binary files /dev/null and b/images/coco/034.jpg differ diff --git a/images/coco/035.jpg b/images/coco/035.jpg new file mode 100755 index 0000000000000000000000000000000000000000..a3ce2201cad5dda06080290287d38312514bb223 Binary files /dev/null and b/images/coco/035.jpg differ diff --git a/images/coco/036.jpg b/images/coco/036.jpg new file mode 100755 index 0000000000000000000000000000000000000000..bc1ff9fffc8662d957a9fe15237ec985c96eb2ad Binary files /dev/null and b/images/coco/036.jpg differ diff --git a/images/coco/037.jpg b/images/coco/037.jpg new file mode 100755 index 0000000000000000000000000000000000000000..1182e979223e8e0a9073e59e4701611f85245cf3 Binary files /dev/null and b/images/coco/037.jpg differ diff --git a/images/coco/038.jpg b/images/coco/038.jpg new file mode 100755 index 0000000000000000000000000000000000000000..1428c8d72d432df4bfd167f71813b720476bb0df Binary files /dev/null and b/images/coco/038.jpg differ diff --git a/images/coco/039.jpg b/images/coco/039.jpg new file mode 100755 index 0000000000000000000000000000000000000000..c16d9e599d845bfbf945d8fdcda98012338cfebe Binary files /dev/null and b/images/coco/039.jpg differ diff --git a/images/coco/040.jpg b/images/coco/040.jpg new file mode 100755 index 0000000000000000000000000000000000000000..f8c10c7c7f832be20595fe1f582f9b7b2e42b2f7 Binary files /dev/null and b/images/coco/040.jpg differ diff --git a/images/coco/041.jpg b/images/coco/041.jpg new file mode 100755 index 0000000000000000000000000000000000000000..65ecf1b5da6fc9cf958a263885f27dd3f260b048 Binary files /dev/null and b/images/coco/041.jpg differ diff --git a/images/coco/042.jpg b/images/coco/042.jpg new file mode 100755 index 0000000000000000000000000000000000000000..ceac8978d3b72a4a1eec17cf4355ab7d20de1a6f Binary files /dev/null and b/images/coco/042.jpg differ diff --git a/images/coco/043.jpg b/images/coco/043.jpg new file mode 100755 index 0000000000000000000000000000000000000000..3c9de73569c928fc4937df914ea99ab5c45de0cc Binary files /dev/null and b/images/coco/043.jpg differ diff --git a/images/coco/044.jpg b/images/coco/044.jpg new file mode 100755 index 0000000000000000000000000000000000000000..cc78c5d44252c507c0ae0c54fd74542743e6f910 Binary files /dev/null and b/images/coco/044.jpg differ diff --git a/images/coco/045.jpg b/images/coco/045.jpg new file mode 100755 index 0000000000000000000000000000000000000000..feceffce5ca35f547c2e9043ea5fa864c6cf96d5 Binary files /dev/null and b/images/coco/045.jpg differ diff --git a/images/coco/046.jpg b/images/coco/046.jpg new file mode 100755 index 0000000000000000000000000000000000000000..74b22da715685f4854b46bdf69d6db9ae2b97b68 Binary files /dev/null and 
b/images/coco/046.jpg differ diff --git a/images/coco/047.jpg b/images/coco/047.jpg new file mode 100755 index 0000000000000000000000000000000000000000..daadf5279748b32bbd150bc52b0a365906df4db6 Binary files /dev/null and b/images/coco/047.jpg differ diff --git a/images/coco/048.jpg b/images/coco/048.jpg new file mode 100755 index 0000000000000000000000000000000000000000..c8f2aa41dde43ea864aefbe1a59b938d93766cb2 Binary files /dev/null and b/images/coco/048.jpg differ diff --git a/images/coco/049.jpg b/images/coco/049.jpg new file mode 100755 index 0000000000000000000000000000000000000000..8cfc5971523a1904929227094253f32135698475 Binary files /dev/null and b/images/coco/049.jpg differ diff --git a/images/coco/050.jpg b/images/coco/050.jpg new file mode 100755 index 0000000000000000000000000000000000000000..914a9f4cd9157d32e2a7cbb7b581d6f41e462c5e Binary files /dev/null and b/images/coco/050.jpg differ diff --git a/images/coco/051.jpg b/images/coco/051.jpg new file mode 100755 index 0000000000000000000000000000000000000000..f3ac1a476f8720990f7a813b04d6a56cc6dc8953 Binary files /dev/null and b/images/coco/051.jpg differ diff --git a/images/coco/052.jpg b/images/coco/052.jpg new file mode 100755 index 0000000000000000000000000000000000000000..ee0901142974bcc57b6f6bd82fb5857f3227cf37 Binary files /dev/null and b/images/coco/052.jpg differ diff --git a/images/coco/053.jpg b/images/coco/053.jpg new file mode 100755 index 0000000000000000000000000000000000000000..494b48c99e0ceb33a28e9f99d566525eaed310ea Binary files /dev/null and b/images/coco/053.jpg differ diff --git a/images/coco/054.jpg b/images/coco/054.jpg new file mode 100755 index 0000000000000000000000000000000000000000..6011c1c0ab284065d31194ce8f1ce1b1d1ef78da Binary files /dev/null and b/images/coco/054.jpg differ diff --git a/images/coco/055.jpg b/images/coco/055.jpg new file mode 100755 index 0000000000000000000000000000000000000000..4af1cb2af083ebd0dfabf9f9933ce4b3c5c1884d Binary files /dev/null and b/images/coco/055.jpg differ diff --git a/images/coco/056.jpg b/images/coco/056.jpg new file mode 100755 index 0000000000000000000000000000000000000000..9763bf5b6dd0065830c8b12fa3b306dcb16717af Binary files /dev/null and b/images/coco/056.jpg differ diff --git a/images/coco/057.jpg b/images/coco/057.jpg new file mode 100755 index 0000000000000000000000000000000000000000..1309367d08b8952df03451d9d653a1f0b0b26bbf Binary files /dev/null and b/images/coco/057.jpg differ diff --git a/images/coco/058.jpg b/images/coco/058.jpg new file mode 100755 index 0000000000000000000000000000000000000000..32a967693e20ca9d9a8117329d4cce5f589204e2 Binary files /dev/null and b/images/coco/058.jpg differ diff --git a/images/coco/059.jpg b/images/coco/059.jpg new file mode 100755 index 0000000000000000000000000000000000000000..5497a2f65fefff6e8821c5f582bc0ff8ec6c25b1 Binary files /dev/null and b/images/coco/059.jpg differ diff --git a/images/coco/060.jpg b/images/coco/060.jpg new file mode 100755 index 0000000000000000000000000000000000000000..e1420e3628c96c4e3497d27338b8f691dee4d80f Binary files /dev/null and b/images/coco/060.jpg differ diff --git a/images/coco/061.jpg b/images/coco/061.jpg new file mode 100755 index 0000000000000000000000000000000000000000..8105a552f5281a629ff15a74d10bebfc1f704379 Binary files /dev/null and b/images/coco/061.jpg differ diff --git a/images/coco/062.jpg b/images/coco/062.jpg new file mode 100755 index 0000000000000000000000000000000000000000..38cdb3357d35cf7c9ace5cf5c95f364c0c3c7a4f Binary files /dev/null and 
b/images/coco/062.jpg differ diff --git a/images/coco/063.jpg b/images/coco/063.jpg new file mode 100755 index 0000000000000000000000000000000000000000..410eab482ac2b7ee6eb1ec0efc3530605e540f1f Binary files /dev/null and b/images/coco/063.jpg differ diff --git a/images/coco/064.jpg b/images/coco/064.jpg new file mode 100755 index 0000000000000000000000000000000000000000..2bb18c7c2edc6ba2f36667e83d8f3148fcdb8bf5 Binary files /dev/null and b/images/coco/064.jpg differ diff --git a/images/coco/065.jpg b/images/coco/065.jpg new file mode 100755 index 0000000000000000000000000000000000000000..33900cabc587a7fe24888297aa6d9295663104fe Binary files /dev/null and b/images/coco/065.jpg differ diff --git a/images/coco/066.jpg b/images/coco/066.jpg new file mode 100755 index 0000000000000000000000000000000000000000..8d0df547ea31d3d7e5062dd9ceec94f7a04772ac Binary files /dev/null and b/images/coco/066.jpg differ diff --git a/images/coco/067.jpg b/images/coco/067.jpg new file mode 100755 index 0000000000000000000000000000000000000000..9f978f909c374b58aec661a58636e5a77d4c334f Binary files /dev/null and b/images/coco/067.jpg differ diff --git a/images/coco/068.jpg b/images/coco/068.jpg new file mode 100755 index 0000000000000000000000000000000000000000..98233d53711883ec109b8566dfcbd25da6b354a4 Binary files /dev/null and b/images/coco/068.jpg differ diff --git a/images/coco/069.jpg b/images/coco/069.jpg new file mode 100755 index 0000000000000000000000000000000000000000..32f079f4e5a7c46fc2cf2b707a69662025923402 Binary files /dev/null and b/images/coco/069.jpg differ diff --git a/images/coco/070.jpg b/images/coco/070.jpg new file mode 100755 index 0000000000000000000000000000000000000000..52a51793ab449d32f9393a7807a53b42a511cf93 Binary files /dev/null and b/images/coco/070.jpg differ diff --git a/images/coco/071.jpg b/images/coco/071.jpg new file mode 100755 index 0000000000000000000000000000000000000000..fdef33732008ec94e34fa1c92951d3cf4acc3112 Binary files /dev/null and b/images/coco/071.jpg differ diff --git a/images/coco/072.jpg b/images/coco/072.jpg new file mode 100755 index 0000000000000000000000000000000000000000..f4cb09c4b01a3633e7fd1ea680a7112b41b28abf Binary files /dev/null and b/images/coco/072.jpg differ diff --git a/images/coco/073.jpg b/images/coco/073.jpg new file mode 100755 index 0000000000000000000000000000000000000000..eaf228188e2485875923898f0e75dca8ca3a29c1 Binary files /dev/null and b/images/coco/073.jpg differ diff --git a/images/coco/074.jpg b/images/coco/074.jpg new file mode 100755 index 0000000000000000000000000000000000000000..a8f21f19eff1034b20581dd6068b8666e6616e76 Binary files /dev/null and b/images/coco/074.jpg differ diff --git a/images/coco/075.jpg b/images/coco/075.jpg new file mode 100755 index 0000000000000000000000000000000000000000..bab1c099eb248c00bc5c58f5d77c6cde63f4927a Binary files /dev/null and b/images/coco/075.jpg differ diff --git a/images/coco/076.jpg b/images/coco/076.jpg new file mode 100755 index 0000000000000000000000000000000000000000..6f7d5129d837d8fc23ea3199672a6828f232a4ba Binary files /dev/null and b/images/coco/076.jpg differ diff --git a/images/coco/077.jpg b/images/coco/077.jpg new file mode 100755 index 0000000000000000000000000000000000000000..56b4dc146da19c025cad7069dffe7b7c66f59720 Binary files /dev/null and b/images/coco/077.jpg differ diff --git a/images/coco/078.jpg b/images/coco/078.jpg new file mode 100755 index 0000000000000000000000000000000000000000..39849b1532f4276e33f72afc429edde836674057 Binary files /dev/null and 
b/images/coco/078.jpg differ diff --git a/images/coco/079.jpg b/images/coco/079.jpg new file mode 100755 index 0000000000000000000000000000000000000000..33c32bb7cc96f8768592d874734b5c52b6775a3e Binary files /dev/null and b/images/coco/079.jpg differ diff --git a/images/dog.png b/images/dog.png new file mode 100644 index 0000000000000000000000000000000000000000..e84dfc8554d344a69afb1fe8c7b8b2997d4e5e11 Binary files /dev/null and b/images/dog.png differ diff --git a/images/fruit.jpg b/images/fruit.jpg new file mode 100644 index 0000000000000000000000000000000000000000..99d107a84272ceb572017df1a041c0c3b235bdaf Binary files /dev/null and b/images/fruit.jpg differ diff --git a/images/horse.png b/images/horse.png new file mode 100644 index 0000000000000000000000000000000000000000..4c41798102e6b523037b66e5bb2407776c1ae4d7 Binary files /dev/null and b/images/horse.png differ diff --git a/images/landscape.jpg b/images/landscape.jpg new file mode 100644 index 0000000000000000000000000000000000000000..91578f776de896f758b78698d71d61ad19894acc Binary files /dev/null and b/images/landscape.jpg differ diff --git a/images/mountain.jpeg b/images/mountain.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..17e8abfb022e4ded0c0ea6edfc6344b1dea2c758 Binary files /dev/null and b/images/mountain.jpeg differ diff --git a/images/owls.jpeg b/images/owls.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..63bfd80a5e973f6b3eab718252779357174cd813 Binary files /dev/null and b/images/owls.jpeg differ diff --git a/images/penguin.jpeg b/images/penguin.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..b76253073d946989177af3cc9928b2d8265d39b0 Binary files /dev/null and b/images/penguin.jpeg differ diff --git a/images/rose.webp b/images/rose.webp new file mode 100644 index 0000000000000000000000000000000000000000..33657ae71e37b955bedd3209ae2a79f58a8c62d3 Binary files /dev/null and b/images/rose.webp differ diff --git a/images/street.jpg b/images/street.jpg new file mode 100644 index 0000000000000000000000000000000000000000..39294f30f004e7f5dafbfacdc75953812b1ca845 Binary files /dev/null and b/images/street.jpg differ diff --git a/tasks/__init__.py b/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd57fbd0eceb1a0079215a008f3392cd8d661b9b --- /dev/null +++ b/tasks/__init__.py @@ -0,0 +1,11 @@ +from .img_cap import image_captioning +from .open_inst import open_instseg +from .open_pano import open_panoseg +from .open_sem import open_semseg +from .ref_cap import referring_captioning +from .ref_in import referring_inpainting +from .ref_seg import referring_segmentation +from .text_ret import text_retrieval +from .reg_ret import region_retrieval +from .ref_in_gpu3 import referring_inpainting_gpt3 +from . 
import img_cap, open_inst, open_pano, open_sem, ref_cap, ref_in, ref_seg, text_ret \ No newline at end of file diff --git a/tasks/__pycache__/__init__.cpython-38.pyc b/tasks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6f2f965ba49ad89ec03054395fa4f773f8a9aa2 Binary files /dev/null and b/tasks/__pycache__/__init__.cpython-38.pyc differ diff --git a/tasks/__pycache__/img_cap.cpython-38.pyc b/tasks/__pycache__/img_cap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35afcf25d4807ece1cb0da61b2d005226528c24b Binary files /dev/null and b/tasks/__pycache__/img_cap.cpython-38.pyc differ diff --git a/tasks/__pycache__/open_inst.cpython-38.pyc b/tasks/__pycache__/open_inst.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..441a4ce156ef177d59ffe09b731c0490b7736c9b Binary files /dev/null and b/tasks/__pycache__/open_inst.cpython-38.pyc differ diff --git a/tasks/__pycache__/open_pano.cpython-38.pyc b/tasks/__pycache__/open_pano.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c49897c9ca1c434b5b8f81bdefc546c7fac263a1 Binary files /dev/null and b/tasks/__pycache__/open_pano.cpython-38.pyc differ diff --git a/tasks/__pycache__/open_sem.cpython-38.pyc b/tasks/__pycache__/open_sem.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c14aa8a610b8d11d75e2d98f2a901622b065bcf2 Binary files /dev/null and b/tasks/__pycache__/open_sem.cpython-38.pyc differ diff --git a/tasks/__pycache__/ref_cap.cpython-38.pyc b/tasks/__pycache__/ref_cap.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70abba737c0acc98a98bef379d369cb3052fbb21 Binary files /dev/null and b/tasks/__pycache__/ref_cap.cpython-38.pyc differ diff --git a/tasks/__pycache__/ref_in.cpython-38.pyc b/tasks/__pycache__/ref_in.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e13b881c93fee5778746b549e1f31b0db3597ca2 Binary files /dev/null and b/tasks/__pycache__/ref_in.cpython-38.pyc differ diff --git a/tasks/__pycache__/ref_in_gpu3.cpython-38.pyc b/tasks/__pycache__/ref_in_gpu3.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a63a7c80c1c483656ef03b45b41414e54922d17 Binary files /dev/null and b/tasks/__pycache__/ref_in_gpu3.cpython-38.pyc differ diff --git a/tasks/__pycache__/ref_seg.cpython-38.pyc b/tasks/__pycache__/ref_seg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32123156fcce574fce7f9793a7b19df9aebf0404 Binary files /dev/null and b/tasks/__pycache__/ref_seg.cpython-38.pyc differ diff --git a/tasks/__pycache__/reg_ret.cpython-38.pyc b/tasks/__pycache__/reg_ret.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a9254cbab72382804bb2747c49b704862827a15 Binary files /dev/null and b/tasks/__pycache__/reg_ret.cpython-38.pyc differ diff --git a/tasks/__pycache__/text_ret.cpython-38.pyc b/tasks/__pycache__/text_ret.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06865b882866b0faa57cf01e2a53026f2bb5254f Binary files /dev/null and b/tasks/__pycache__/text_ret.cpython-38.pyc differ diff --git a/tasks/img_cap.py b/tasks/img_cap.py new file mode 100644 index 0000000000000000000000000000000000000000..2d0edf253896327a6ac5e244ed1b54696c7db9cd --- /dev/null +++ b/tasks/img_cap.py @@ -0,0 +1,55 @@ +# -------------------------------------------------------- +# X-Decoder -- 
Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import cv2 +import torch +import numpy as np +from PIL import Image +from torchvision import transforms + + +t = [] +t.append(transforms.Resize(224, interpolation=Image.BICUBIC)) +transform = transforms.Compose(t) + +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform_v = transforms.Compose(t) + +def image_captioning(model, image, texts, inpainting_text, *args, **kwargs): + with torch.no_grad(): + image_ori = transform_v(image) + width = image_ori.size[0] + height = image_ori.size[1] + image_ori = np.asarray(image_ori) + + image = transform(image) + image = np.asarray(image) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + + batch_inputs = [{'image': images, 'height': height, 'width': width, 'image_id': 0}] + outputs = model.model.evaluate_captioning(batch_inputs) + text = outputs[-1]['captioning_text'] + + image_ori = image_ori.copy() + cv2.rectangle(image_ori, (0, height-60), (width, height), (0,0,0), -1) + font = cv2.FONT_HERSHEY_DUPLEX + fontScale = 1.2 + thickness = 2 + lineType = 2 + bottomLeftCornerOfText = (10, height-20) + fontColor = [255,255,255] + cv2.putText(image_ori, text, + bottomLeftCornerOfText, + font, + fontScale, + fontColor, + thickness, + lineType) + torch.cuda.empty_cache() + return Image.fromarray(image_ori), text, None + diff --git a/tasks/open_inst.py b/tasks/open_inst.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf1686a0b20c8f54aca9a308afef7cf6dfed166 --- /dev/null +++ b/tasks/open_inst.py @@ -0,0 +1,60 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import torch +import numpy as np +from PIL import Image +from torchvision import transforms +from utils.visualizer import Visualizer +from detectron2.utils.colormap import random_color +from detectron2.data import MetadataCatalog +from detectron2.structures import BitMasks + + +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform = transforms.Compose(t) +metadata = MetadataCatalog.get('ade20k_panoptic_train') + +def open_instseg(model, image, texts, inpainting_text, *args, **kwargs): + thing_classes = [x.strip() for x in texts.split(',')] + thing_colors = [random_color(rgb=True, maximum=255).astype(np.int32).tolist() for _ in range(len(thing_classes))] + thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))} + + MetadataCatalog.get("demo").set( + thing_colors=thing_colors, + thing_classes=thing_classes, + thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id, + ) + + with torch.no_grad(): + model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + ["background"], is_eval=True) + + metadata = MetadataCatalog.get('demo') + model.model.metadata = metadata + model.model.sem_seg_head.num_classes = len(thing_classes) + + image_ori = transform(image) + width = image_ori.size[0] + height = image_ori.size[1] + image = np.asarray(image_ori) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + + batch_inputs = [{'image': 
images, 'height': height, 'width': width}] + outputs = model.forward(batch_inputs) + visual = Visualizer(image_ori, metadata=metadata) + + inst_seg = outputs[-1]['instances'] + inst_seg.pred_masks = inst_seg.pred_masks.cpu() + inst_seg.pred_boxes = BitMasks(inst_seg.pred_masks > 0).get_bounding_boxes() + demo = visual.draw_instance_predictions(inst_seg) # rgb Image + res = demo.get_image() + + + MetadataCatalog.remove('demo') + torch.cuda.empty_cache() + return Image.fromarray(res), '', None diff --git a/tasks/open_pano.py b/tasks/open_pano.py new file mode 100644 index 0000000000000000000000000000000000000000..48a05f3ec5a0e78568cc4a47c6433b52a4330e8b --- /dev/null +++ b/tasks/open_pano.py @@ -0,0 +1,70 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import torch +import numpy as np +from PIL import Image +from torchvision import transforms +from utils.visualizer import Visualizer +from detectron2.utils.colormap import random_color +from detectron2.data import MetadataCatalog + + +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform = transforms.Compose(t) +metadata = MetadataCatalog.get('ade20k_panoptic_train') + +def open_panoseg(model, image, texts, inpainting_text, *args, **kwargs): + stuff_classes = [x.strip() for x in texts.split(';')[0].replace('stuff:','').split(',')] + thing_classes = [x.strip() for x in texts.split(';')[1].replace('thing:','').split(',')] + thing_colors = [random_color(rgb=True, maximum=255).astype(np.int32).tolist() for _ in range(len(thing_classes))] + stuff_colors = [random_color(rgb=True, maximum=255).astype(np.int32).tolist() for _ in range(len(stuff_classes))] + thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))} + stuff_dataset_id_to_contiguous_id = {x+len(thing_classes):x for x in range(len(stuff_classes))} + + MetadataCatalog.get("demo").set( + thing_colors=thing_colors, + thing_classes=thing_classes, + thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id, + stuff_colors=stuff_colors, + stuff_classes=stuff_classes, + stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id, + ) + model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + stuff_classes + ["background"], is_eval=True) + metadata = MetadataCatalog.get('demo') + model.model.metadata = metadata + model.model.sem_seg_head.num_classes = len(thing_classes + stuff_classes) + + with torch.no_grad(): + image_ori = transform(image) + width = image_ori.size[0] + height = image_ori.size[1] + image = transform(image_ori) + image = np.asarray(image) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + + batch_inputs = [{'image': images, 'height': height, 'width': width}] + outputs = model.forward(batch_inputs) + visual = Visualizer(image_ori, metadata=metadata) + + pano_seg = outputs[-1]['panoptic_seg'][0] + pano_seg_info = outputs[-1]['panoptic_seg'][1] + + for i in range(len(pano_seg_info)): + if pano_seg_info[i]['category_id'] in metadata.thing_dataset_id_to_contiguous_id.keys(): + pano_seg_info[i]['category_id'] = metadata.thing_dataset_id_to_contiguous_id[pano_seg_info[i]['category_id']] + else: + pano_seg_info[i]['isthing'] = False + pano_seg_info[i]['category_id'] = 
metadata.stuff_dataset_id_to_contiguous_id[pano_seg_info[i]['category_id']] + + demo = visual.draw_panoptic_seg(pano_seg.cpu(), pano_seg_info) # rgb Image + res = demo.get_image() + + MetadataCatalog.remove('demo') + torch.cuda.empty_cache() + return Image.fromarray(res), '', None \ No newline at end of file diff --git a/tasks/open_sem.py b/tasks/open_sem.py new file mode 100644 index 0000000000000000000000000000000000000000..04b95fc9fff82951cf6683a5a2f0632bf30837e4 --- /dev/null +++ b/tasks/open_sem.py @@ -0,0 +1,57 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import os +import cv2 +import torch +import numpy as np +from PIL import Image +from torchvision import transforms +from utils.visualizer import Visualizer +from detectron2.utils.colormap import random_color +from detectron2.data import MetadataCatalog + + +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform = transforms.Compose(t) +metadata = MetadataCatalog.get('ade20k_panoptic_train') + +def open_semseg(model, image, texts, inpainting_text, *args, **kwargs): + stuff_classes = [x.strip() for x in texts.split(',')] + stuff_colors = [random_color(rgb=True, maximum=255).astype(np.int32).tolist() for _ in range(len(stuff_classes))] + stuff_dataset_id_to_contiguous_id = {x:x for x in range(len(stuff_classes))} + + MetadataCatalog.get("demo").set( + stuff_colors=stuff_colors, + stuff_classes=stuff_classes, + stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id, + ) + model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(stuff_classes + ["background"], is_eval=True) + metadata = MetadataCatalog.get('demo') + model.model.metadata = metadata + model.model.sem_seg_head.num_classes = len(stuff_classes) + + with torch.no_grad(): + image_ori = transform(image) + width = image_ori.size[0] + height = image_ori.size[1] + image = transform(image_ori) + image = np.asarray(image) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + + batch_inputs = [{'image': images, 'height': height, 'width': width}] + outputs = model.forward(batch_inputs) + visual = Visualizer(image_ori, metadata=metadata) + + sem_seg = outputs[-1]['sem_seg'].max(0)[1] + demo = visual.draw_sem_seg(sem_seg.cpu(), alpha=0.5) # rgb Image + res = demo.get_image() + + MetadataCatalog.remove('demo') + torch.cuda.empty_cache() + return Image.fromarray(res), '', None \ No newline at end of file diff --git a/tasks/ref_cap.py b/tasks/ref_cap.py new file mode 100644 index 0000000000000000000000000000000000000000..76cd1fd34a038db0fd7a8818ff7a7c764bfb040d --- /dev/null +++ b/tasks/ref_cap.py @@ -0,0 +1,68 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image +from torchvision import transforms +from utils.visualizer import Visualizer +from detectron2.data import MetadataCatalog + +t = [] +t.append(transforms.Resize(224, interpolation=Image.BICUBIC)) +transform_ret = 
transforms.Compose(t) +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform_grd = transforms.Compose(t) + +metedata = MetadataCatalog.get('coco_2017_train_panoptic') + +def referring_captioning(model, image, texts, inpainting_text, *args, **kwargs): + model_last, model_cap = model + with torch.no_grad(): + image_ori = image + image = transform_grd(image) + width = image.size[0] + height = image.size[1] + image = np.asarray(image) + image_ori_ = image + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + texts_input = [[texts.strip() if texts.endswith('.') else (texts + '.')]] + + batch_inputs = [{'image': images, 'groundings': {'texts':texts_input}, 'height': height, 'width': width}] + outputs = model_last.model.evaluate_grounding(batch_inputs, None) + + grd_mask = (outputs[-1]['grounding_mask'] > 0).float() + grd_mask_ = (1 - F.interpolate(grd_mask[None,], (224, 224), mode='nearest')[0]).bool() + + color = [252/255, 91/255, 129/255] + visual = Visualizer(image_ori_, metadata=metedata) + demo = visual.draw_binary_mask(grd_mask.cpu().numpy()[0], color=color, text=texts) + res = demo.get_image() + + if (1 - grd_mask_.float()).sum() < 5: + torch.cuda.empty_cache() + return Image.fromarray(res), 'n/a', None + + grd_mask_ = grd_mask_ * 0 + image = transform_ret(image_ori) + image_ori = np.asarray(image_ori) + image = np.asarray(image) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + batch_inputs = [{'image': images, 'image_id': 0, 'captioning_mask': grd_mask_}] + + token_text = texts.replace('.','') if texts.endswith('.') else texts + token = model_cap.model.sem_seg_head.predictor.lang_encoder.tokenizer.encode(token_text) + token = torch.tensor(token)[None,:-1] + + outputs = model_cap.model.evaluate_captioning(batch_inputs, extra={'token': token}) + # outputs = model_cap.model.evaluate_captioning(batch_inputs, extra={}) + text = outputs[-1]['captioning_text'] + + torch.cuda.empty_cache() + return Image.fromarray(res), text, None \ No newline at end of file diff --git a/tasks/ref_in.py b/tasks/ref_in.py new file mode 100644 index 0000000000000000000000000000000000000000..d872a993eadae5cbd6c37e821232149fa2e3de16 --- /dev/null +++ b/tasks/ref_in.py @@ -0,0 +1,77 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Jianwei Yang (jianwyan@microsoft.com), Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import torch +import numpy as np +from PIL import Image +from utils.inpainting import pad_image +from torchvision import transforms +from utils.visualizer import Visualizer +from diffusers import StableDiffusionInpaintPipeline +from detectron2.utils.colormap import random_color +from detectron2.data import MetadataCatalog +from scipy import ndimage + + +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform = transforms.Compose(t) +metadata = MetadataCatalog.get('ade20k_panoptic_train') + +pipe = StableDiffusionInpaintPipeline.from_pretrained( + # "stabilityai/stable-diffusion-2-inpainting", + "runwayml/stable-diffusion-inpainting", + revision="fp16", + torch_dtype=torch.float16, +).to("cuda") + +def crop_image(input_image): + crop_w, crop_h = np.floor(np.array(input_image.size) / 64).astype(int) * 64 + im_cropped = Image.fromarray(np.array(input_image)[:crop_h, :crop_w]) + return im_cropped + 
+def referring_inpainting(model, image, texts, inpainting_text, *args, **kwargs): + model.model.metadata = metadata + texts = [[texts if texts.strip().endswith('.') else (texts.strip() + '.')]] + image_ori = crop_image(transform(image)) + + with torch.no_grad(): + width = image_ori.size[0] + height = image_ori.size[1] + image = np.asarray(image_ori) + image_ori_np = np.asarray(image_ori) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + + batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': texts}}] + outputs = model.model.evaluate_grounding(batch_inputs, None) + visual = Visualizer(image_ori_np, metadata=metadata) + + grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy() + for idx, mask in enumerate(grd_mask): + color = random_color(rgb=True, maximum=1).astype(np.int32).tolist() + demo = visual.draw_binary_mask(mask, color=color, text=texts[idx]) + res = demo.get_image() + + if inpainting_text not in ['no', '']: + # if we want to do inpainting + image_crop = image_ori + struct2 = ndimage.generate_binary_structure(2, 2) + mask_dilated = ndimage.binary_dilation(grd_mask[0], structure=struct2, iterations=3).astype(grd_mask[0].dtype) + mask = Image.fromarray(mask_dilated * 255).convert('RGB') + image_and_mask = { + "image": image_crop, + "mask": mask, + } + width = image_crop.size[0]; height = image_crop.size[1] + images_inpainting = pipe(prompt = inpainting_text.strip(), image=image_and_mask['image'], mask_image=image_and_mask['mask'], height=height, width=width).images[0] + # put images_inpainting back to original image + # image_ori.paste(images_inpainting) + torch.cuda.empty_cache() + return Image.fromarray(res) ,'' , images_inpainting + else: + torch.cuda.empty_cache() + return image_ori, 'text', Image.fromarray(res) \ No newline at end of file diff --git a/tasks/ref_in_gpu3.py b/tasks/ref_in_gpu3.py new file mode 100644 index 0000000000000000000000000000000000000000..c578666beb7c31b10caac15d2c6ba0db4aa2b89b --- /dev/null +++ b/tasks/ref_in_gpu3.py @@ -0,0 +1,103 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Jianwei Yang (jianwyan@microsoft.com) +# -------------------------------------------------------- +import os +import openai +import torch +import numpy as np +from scipy import ndimage +from PIL import Image +from utils.inpainting import pad_image, crop_image +from torchvision import transforms +from utils.visualizer import Visualizer +from diffusers import StableDiffusionInpaintPipeline +from detectron2.utils.colormap import random_color +from detectron2.data import MetadataCatalog + + +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform = transforms.Compose(t) +metadata = MetadataCatalog.get('ade20k_panoptic_train') + +pipe = StableDiffusionInpaintPipeline.from_pretrained( + # "stabilityai/stable-diffusion-2-inpainting", + "runwayml/stable-diffusion-inpainting", + revision="fp16", + torch_dtype=torch.float16, +).to("cuda") + +prompts = [] +prompts.append("remove the person, task: (referring editing), source: [person], target:;") +prompts.append("remove the person in the middle, task: (referring editing), source: [person in the middle], target:;") +prompts.append("remove the dog on the left side, task: (referring editing), source: [dog on the left side], target:;") +prompts.append("change the apple 
to a pear, task: (referring editing), source: [apple], target: ;") +prompts.append("change the red apple to a green one, task: (referring editing), source: [red apple], target: ;") +prompts.append("replace the dog with a cat, task: (referring editing), source: [dog], target: ;") +prompts.append("replace the red apple with a green one, task: (referring editing), source: [red apple], target: ;") + +def get_gpt3_response(prompt): + openai.api_key = os.getenv("OPENAI_API_KEY") + + response = openai.Completion.create( + model="text-davinci-003", + prompt=prompt, + temperature=0.7, + max_tokens=128, + top_p=1, + frequency_penalty=0, + presence_penalty=0 + ) + + return response + +def referring_inpainting_gpt3(model, image, instruction, *args, **kwargs): + # convert instruction to source and target + print(instruction) + resp = get_gpt3_response(' '.join(prompts) + instruction + ',') + resp_text = resp['choices'][0]['text'] + print(resp_text) + ref_text = resp_text[resp_text.find('[')+1:resp_text.find(']')] + inp_text = resp_text[resp_text.find('<')+1:resp_text.find('>')] + + model.model.metadata = metadata + texts = [[ref_text if ref_text.strip().endswith('.') else (ref_text.strip() + '.')]] + image_ori = crop_image(transform(image)) + + with torch.no_grad(): + width = image_ori.size[0] + height = image_ori.size[1] + image = np.asarray(image_ori) + image_ori_np = np.asarray(image_ori) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + + batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': texts}}] + outputs = model.model.evaluate_grounding(batch_inputs, None) + visual = Visualizer(image_ori_np, metadata=metadata) + + grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy() + for idx, mask in enumerate(grd_mask): + color = random_color(rgb=True, maximum=1).astype(np.int32).tolist() + demo = visual.draw_binary_mask(mask, color=color, text=texts[idx]) + res = demo.get_image() + + if inp_text not in ['no', '']: + image_crop = image_ori + struct2 = ndimage.generate_binary_structure(2, 2) + mask_dilated = ndimage.binary_dilation(grd_mask[0], structure=struct2, iterations=3).astype(grd_mask[0].dtype) + mask = Image.fromarray(mask_dilated * 255).convert('RGB') + image_and_mask = { + "image": image_crop, + "mask": mask, + } + # images_inpainting = inpainting(inpainting_model, image_and_mask, inp_text, ddim_steps, num_samples, scale, seed) + width = image_ori.size[0]; height = image_ori.size[1] + images_inpainting = pipe(prompt = inp_text.strip(), image=image_and_mask['image'], mask_image=image_and_mask['mask'], height=height, width=width).images + torch.cuda.empty_cache() + return Image.fromarray(res), resp_text, images_inpainting[0] + else: + torch.cuda.empty_cache() + return image_ori, resp_text, Image.fromarray(res) \ No newline at end of file diff --git a/tasks/ref_seg.py b/tasks/ref_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a832d8c36b8584ca0784af3c7346c9825e2b6b --- /dev/null +++ b/tasks/ref_seg.py @@ -0,0 +1,46 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import torch +import numpy as np +from PIL import Image +from torchvision import transforms +from utils.visualizer import Visualizer +from detectron2.utils.colormap 
import random_color +from detectron2.data import MetadataCatalog + + +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform = transforms.Compose(t) +metadata = MetadataCatalog.get('ade20k_panoptic_train') + +def referring_segmentation(model, image, texts, inpainting_text, *args, **kwargs): + model.model.metadata = metadata + texts = texts.strip() + texts = [[text.strip() if text.endswith('.') else (text + '.')] for text in texts.split(',')] + image_ori = transform(image) + + with torch.no_grad(): + width = image_ori.size[0] + height = image_ori.size[1] + image = np.asarray(image_ori) + image_ori_np = np.asarray(image_ori) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + + batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': texts}}] + outputs = model.model.evaluate_grounding(batch_inputs, None) + visual = Visualizer(image_ori_np, metadata=metadata) + + grd_mask = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy() + for idx, mask in enumerate(grd_mask): + color = random_color(rgb=True, maximum=1).astype(np.int32).tolist() + demo = visual.draw_binary_mask(mask, color=color, text=texts[idx]) + res = demo.get_image() + + torch.cuda.empty_cache() + return Image.fromarray(res), '', None \ No newline at end of file diff --git a/tasks/reg_ret.py b/tasks/reg_ret.py new file mode 100644 index 0000000000000000000000000000000000000000..f475cca2c29cc380a7c27d7493fdb227464eb5f6 --- /dev/null +++ b/tasks/reg_ret.py @@ -0,0 +1,72 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import glob +import os +import torch +import numpy as np +from PIL import Image +from torchvision import transforms +from detectron2.data import MetadataCatalog +from utils.visualizer import Visualizer +from xdecoder.language.loss import vl_similarity +from detectron2.utils.colormap import random_color + + +t = [] +t.append(transforms.Resize((224,224), interpolation=Image.BICUBIC)) +transform_ret = transforms.Compose(t) +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform_grd = transforms.Compose(t) +metadata = MetadataCatalog.get('coco_2017_train_panoptic') + +imgs_root = 'images/coco' +img_pths = sorted(glob.glob(os.path.join(imgs_root, '*.jpg'))) +imgs = [Image.open(x).convert('RGB') for x in img_pths] +v_emb = torch.load("v_emb.da") + +def region_retrieval(model, image, texts, inpainting_text, *args, **kwargs): + model_novg, model_seg = model + with torch.no_grad(): + # images = [transform_ret(x) for x in imgs] + # images = [np.asarray(x) for x in imgs] + # images = [torch.from_numpy(x.copy()).permute(2,0,1).cuda() for x in images] + # batch_inputs = [{'image': image, 'image_id': 0} for image in images] + # outputs = model_novg.model.evaluate(batch_inputs) + # v_emb = torch.cat([x['captions'][-1:] for x in outputs]) + # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + # torch.save(v_emb, "v_emb.da") + # exit() + + texts_ = [[x.strip() if x.strip().endswith('.') else (x.strip() + '.')] for x in texts.split(',')] + model_novg.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts_, is_eval=False, name='caption', prompt=False) + t_emb = getattr(model_novg.model.sem_seg_head.predictor.lang_encoder, 
'{}_text_embeddings'.format('caption')) + temperature = model_novg.model.sem_seg_head.predictor.lang_encoder.logit_scale + + logits = vl_similarity(v_emb, t_emb, temperature) + prob, idx = logits[:,0].softmax(-1).max(0) + image_ori = imgs[idx] + image = transform_grd(image_ori) + width, height = image.size + image = np.asarray(image) + image_ori = np.asarray(image) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + batch_inputs = [{'image': images, 'height': height, 'width': width, 'groundings': {'texts': texts_}}] + model_seg.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts_, is_eval=False, name='caption', prompt=False) + outputs = model_seg.model.evaluate_grounding(batch_inputs, None) + + visual = Visualizer(image_ori, metadata=metadata) + grd_masks = (outputs[0]['grounding_mask'] > 0).float().cpu().numpy() + + for text, mask in zip([x[0] for x in texts_], grd_masks): + color = random_color(rgb=True, maximum=1).astype(np.int32).tolist() + demo = visual.draw_binary_mask(mask, color=color, text=texts, alpha=0.5) + res = demo.get_image() + + torch.cuda.empty_cache() + return Image.fromarray(res), "Selected Image Probability: {:.2f}".format(prob.item()), None \ No newline at end of file diff --git a/tasks/text_ret.py b/tasks/text_ret.py new file mode 100644 index 0000000000000000000000000000000000000000..65d6831ec9b8d60806cc8237bdd5b4366791d1a8 --- /dev/null +++ b/tasks/text_ret.py @@ -0,0 +1,46 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import torch +import numpy as np +from PIL import Image +from torchvision import transforms +from detectron2.data import MetadataCatalog +from xdecoder.language.loss import vl_similarity + + +t = [] +t.append(transforms.Resize(224, interpolation=Image.BICUBIC)) +transform_ret = transforms.Compose(t) +t = [] +t.append(transforms.Resize(512, interpolation=Image.BICUBIC)) +transform_grd = transforms.Compose(t) + +metedata = MetadataCatalog.get('coco_2017_train_panoptic') + +def text_retrieval(model, image, texts, inpainting_text, *args, **kwargs): + out_str = '' + with torch.no_grad(): + image = transform_ret(image) + image = np.asarray(image) + images = torch.from_numpy(image.copy()).permute(2,0,1).cuda() + batch_inputs = [{'image': images, 'image_id': 0}] + outputs = model.model.evaluate(batch_inputs) + v_emb = torch.cat([x['captions'][-1:] for x in outputs]) + v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + + texts = [x.strip() for x in texts.split(',')] + model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, is_eval=False, name='caption', prompt=False) + t_emb = getattr(model.model.sem_seg_head.predictor.lang_encoder, '{}_text_embeddings'.format('caption')) + temperature = model.model.sem_seg_head.predictor.lang_encoder.logit_scale + logits = vl_similarity(v_emb, t_emb, temperature) + topk_prob, topk_idx = logits.softmax(-1)[0].topk(min(5, len(texts))) + + for prob, idx in zip(topk_prob, topk_idx): + out_str += "{}:{:.2f}; ".format(texts[idx.item()], prob.item()) + torch.cuda.empty_cache() + return None, out_str, None \ No newline at end of file diff --git a/utils/Config.py b/utils/Config.py new file mode 100755 index 0000000000000000000000000000000000000000..bc9877e4910a2ccfc2ac0d851c5c87ce1e134450 --- 
/dev/null +++ b/utils/Config.py @@ -0,0 +1,26 @@ +from fvcore.common.config import CfgNode as _CfgNode + +class CfgNode(_CfgNode): + """ + The same as `fvcore.common.config.CfgNode`, but different in: + + 1. Use unsafe yaml loading by default. + Note that this may lead to arbitrary code execution: you must not + load a config file from untrusted sources before manually inspecting + the content of the file. + 2. Support config versioning. + When attempting to merge an old config, it will convert the old config automatically. + + .. automethod:: clone + .. automethod:: freeze + .. automethod:: defrost + .. automethod:: is_frozen + .. automethod:: load_yaml_with_base + .. automethod:: merge_from_list + .. automethod:: merge_from_other_cfg + """ + + def merge_from_dict(self, dict): + pass + +node = CfgNode() \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d91ddcd0e21b1b23b4f087afa957099819709cd4 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/arguments.cpython-38.pyc b/utils/__pycache__/arguments.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a96df3d6a0068a06f356c91c91cd325643a80417 Binary files /dev/null and b/utils/__pycache__/arguments.cpython-38.pyc differ diff --git a/utils/__pycache__/ddim.cpython-38.pyc b/utils/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3619e86525fa7080e9a39674814a6730c884b8c Binary files /dev/null and b/utils/__pycache__/ddim.cpython-38.pyc differ diff --git a/utils/__pycache__/distributed.cpython-38.pyc b/utils/__pycache__/distributed.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0787a211bf923ce97f11528bd1bd24bbfe829bc Binary files /dev/null and b/utils/__pycache__/distributed.cpython-38.pyc differ diff --git a/utils/__pycache__/inpainting.cpython-38.pyc b/utils/__pycache__/inpainting.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..153a6ffb8a907c5387cd11c33133da7065f5ab31 Binary files /dev/null and b/utils/__pycache__/inpainting.cpython-38.pyc differ diff --git a/utils/__pycache__/misc.cpython-38.pyc b/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a8a82c96d2cd9dc96f54ab62c78812cf2b86764 Binary files /dev/null and b/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/utils/__pycache__/model.cpython-38.pyc b/utils/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..806d89e45f20da0057fcd7de13223bcd33492267 Binary files /dev/null and b/utils/__pycache__/model.cpython-38.pyc differ diff --git a/utils/__pycache__/model_loading.cpython-38.pyc b/utils/__pycache__/model_loading.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4814b4c1db6e19a890360ce16de8d0919ac0352 Binary files /dev/null and b/utils/__pycache__/model_loading.cpython-38.pyc differ diff --git a/utils/__pycache__/util.cpython-38.pyc b/utils/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e579bb8a015066ce4a511d622822edcfc1b66668 Binary files /dev/null and 
b/utils/__pycache__/util.cpython-38.pyc differ diff --git a/utils/__pycache__/visualizer.cpython-38.pyc b/utils/__pycache__/visualizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a032417d583fb8ef232497f849da4ead8d7bfd9 Binary files /dev/null and b/utils/__pycache__/visualizer.cpython-38.pyc differ diff --git a/utils/arguments.py b/utils/arguments.py new file mode 100755 index 0000000000000000000000000000000000000000..c1a3fa8069e15a287aedd7d15828fa6e23c4fda4 --- /dev/null +++ b/utils/arguments.py @@ -0,0 +1,98 @@ +import yaml +import json +import argparse +import logging + +logger = logging.getLogger(__name__) + + +def load_config_dict_to_opt(opt, config_dict): + """ + Load the key, value pairs from config_dict to opt, overriding existing values in opt + if there is any. + """ + if not isinstance(config_dict, dict): + raise TypeError("Config must be a Python dictionary") + for k, v in config_dict.items(): + k_parts = k.split('.') + pointer = opt + for k_part in k_parts[:-1]: + if k_part not in pointer: + pointer[k_part] = {} + pointer = pointer[k_part] + assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict." + ori_value = pointer.get(k_parts[-1]) + pointer[k_parts[-1]] = v + if ori_value: + logger.warning(f"Overrided {k} from {ori_value} to {pointer[k_parts[-1]]}") + + +def load_opt_from_config_files(conf_file): + """ + Load opt from the config files, settings in later files can override those in previous files. + + Args: + conf_files: config file path + + Returns: + dict: a dictionary of opt settings + """ + opt = {} + with open(conf_file, encoding='utf-8') as f: + config_dict = yaml.safe_load(f) + + load_config_dict_to_opt(opt, config_dict) + + return opt + + +def load_opt_command(args): + parser = argparse.ArgumentParser(description='MainzTrain: Pretrain or fine-tune models for NLP tasks.') + parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate') + parser.add_argument('--conf_files', required=True, help='Path(s) to the MainzTrain config file(s).') + parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"": , "..": }. A key with "." updates the object in the corresponding nested dict. 
Remember to escape " in command line.') + parser.add_argument('--overrides', help='arguments that used to overide the config file in cmdline', nargs=argparse.REMAINDER) + + cmdline_args = parser.parse_args() if not args else parser.parse_args(args) + + opt = load_opt_from_config_files(cmdline_args.conf_files) + + if cmdline_args.config_overrides: + config_overrides_string = ' '.join(cmdline_args.config_overrides) + logger.warning(f"Command line config overrides: {config_overrides_string}") + config_dict = json.loads(config_overrides_string) + load_config_dict_to_opt(opt, config_dict) + + if cmdline_args.overrides: + assert len(cmdline_args.overrides) % 2 == 0, "overides arguments is not paired, required: key value" + keys = [cmdline_args.overrides[idx*2] for idx in range(len(cmdline_args.overrides)//2)] + vals = [cmdline_args.overrides[idx*2+1] for idx in range(len(cmdline_args.overrides)//2)] + vals = [val.replace('false', '').replace('False','') if len(val.replace(' ', '')) == 5 else val for val in vals] + + types = [] + for key in keys: + key = key.split('.') + ele = opt.copy() + while len(key) > 0: + ele = ele[key.pop(0)] + types.append(type(ele)) + + config_dict = {x:z(y) for x,y,z in zip(keys, vals, types)} + load_config_dict_to_opt(opt, config_dict) + + # combine cmdline_args into opt dictionary + for key, val in cmdline_args.__dict__.items(): + if val is not None: + opt[key] = val + + return opt, cmdline_args + + +def save_opt_to_json(opt, conf_file): + with open(conf_file, 'w', encoding='utf-8') as f: + json.dump(opt, f, indent=4) + + +def save_opt_to_yaml(opt, conf_file): + with open(conf_file, 'w', encoding='utf-8') as f: + yaml.dump(opt, f) diff --git a/utils/ddim.py b/utils/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..d6366003eb4107c95cf0cf7bbb653000f716d06c --- /dev/null +++ b/utils/ddim.py @@ -0,0 +1,203 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. 
- alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + 
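+        # `time_range` walks the selected timesteps from the noisiest step down to the
+        # cleanest one: with the default "uniform" discretization (e.g. 50 DDIM steps
+        # over 1000 DDPM steps) it is roughly [981, 961, ..., 21, 1], and every loop
+        # iteration below denoises the latent by one DDIM step toward x_0.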
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/utils/distributed.py b/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..521a934de05bca3159bb595cd0ab997ee08dd61a --- /dev/null +++ b/utils/distributed.py @@ -0,0 +1,180 @@ +import os +import time +import torch +import pickle +import torch.distributed as dist + + +def init_distributed(opt): + opt['CUDA'] = opt.get('CUDA', True) and torch.cuda.is_available() + if 'OMPI_COMM_WORLD_SIZE' not in os.environ: + # application was started without MPI + # default to single node with single process + opt['env_info'] = 'no MPI' + opt['world_size'] = 1 + opt['local_size'] = 1 + opt['rank'] = 0 + opt['local_rank'] = 0 + opt['master_address'] = '127.0.0.1' + opt['master_port'] = '8673' + else: + # application was started with MPI + # get MPI parameters + opt['world_size'] = int(os.environ['OMPI_COMM_WORLD_SIZE']) + opt['local_size'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE']) + opt['rank'] = int(os.environ['OMPI_COMM_WORLD_RANK']) + opt['local_rank'] = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + + # set up device + if not opt['CUDA']: + assert opt['world_size'] == 1, 'multi-GPU training without CUDA is not supported since we use NCCL as communication backend' + opt['device'] = torch.device("cpu") + else: + torch.cuda.set_device(opt['local_rank']) + opt['device'] = torch.device("cuda", opt['local_rank']) + return opt + +def is_main_process(): + rank = 0 + if 'OMPI_COMM_WORLD_SIZE' in os.environ: + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + + return rank == 0 + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + rank = dist.get_rank() + if world_size == 1: + return + + def _send_and_wait(r): + if rank == r: + tensor = torch.tensor(0, device="cuda") + else: + tensor = torch.tensor(1, device="cuda") + dist.broadcast(tensor, r) + while tensor.item() == 1: + time.sleep(1) + + _send_and_wait(0) + # now sync on the main process + _send_and_wait(1) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.IntTensor([tensor.numel()]).to("cuda") + size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + 
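+    # e.g. if the pickled payloads on the ranks are 5, 9 and 3 bytes long, every rank
+    # pads its byte tensor to max_size == 9 before the all_gather, and `size_list`
+    # is used below to strip the padding off again before unpickling.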
tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def broadcast_data(data): + if not torch.distributed.is_initialized(): + return data + rank = dist.get_rank() + if rank == 0: + data_tensor = torch.tensor(data + [0], device="cuda") + else: + data_tensor = torch.tensor(data + [1], device="cuda") + torch.distributed.broadcast(data_tensor, 0) + while data_tensor.cpu().numpy()[-1] == 1: + time.sleep(1) + + return data_tensor.cpu().numpy().tolist()[:-1] + + +def reduce_sum(tensor): + if get_world_size() <= 1: + return tensor + + tensor = tensor.clone() + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor \ No newline at end of file diff --git a/utils/inpainting.py b/utils/inpainting.py new file mode 100644 index 0000000000000000000000000000000000000000..bd55afd9578d39a1e235d87fa87fefcfeec0ae1e --- /dev/null +++ b/utils/inpainting.py @@ -0,0 +1,177 @@ +import sys +import cv2 +import torch +import numpy as np +import gradio as gr +from PIL import Image +from omegaconf import OmegaConf +from einops import repeat +from imwatermark import WatermarkEncoder +from pathlib import Path + +from .ddim import DDIMSampler +from .util import instantiate_from_config + + +torch.set_grad_enabled(False) + + +def put_watermark(img, wm_encoder=None): + if wm_encoder is not None: + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) + img = wm_encoder.encode(img, 'dwtDct') + img = Image.fromarray(img[:, :, ::-1]) + return img + + +def initialize_model(config, ckpt): + config = OmegaConf.load(config) + model = instantiate_from_config(config.model) + + model.load_state_dict(torch.load(ckpt)["state_dict"], strict=False) + + device = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + sampler = DDIMSampler(model) + + return sampler + + +def make_batch_sd( + image, + mask, + txt, + device, + num_samples=1): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = 
torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + batch = { + "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples), + "txt": num_samples * [txt], + "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples), + "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples), + } + return batch + +@torch.no_grad() +def inpaint(sampler, image, mask, prompt, seed, scale, ddim_steps, num_samples=1, w=512, h=512): + device = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + model = sampler.model + + print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...") + wm = "SDV2" + wm_encoder = WatermarkEncoder() + wm_encoder.set_watermark('bytes', wm.encode('utf-8')) + + prng = np.random.RandomState(seed) + start_code = prng.randn(num_samples, 4, h // 8, w // 8) + start_code = torch.from_numpy(start_code).to( + device=device, dtype=torch.float32) + + with torch.no_grad(), \ + torch.autocast("cuda"): + batch = make_batch_sd(image, mask, txt=prompt, + device=device, num_samples=num_samples) + + c = model.cond_stage_model.encode(batch["txt"]) + + c_cat = list() + for ck in model.concat_keys: + cc = batch[ck].float() + if ck != model.masked_image_key: + bchw = [num_samples, 4, h // 8, w // 8] + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = model.get_first_stage_encoding( + model.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + + # cond + cond = {"c_concat": [c_cat], "c_crossattn": [c]} + + # uncond cond + uc_cross = model.get_unconditional_conditioning(num_samples, "") + uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]} + + shape = [model.channels, h // 8, w // 8] + samples_cfg, intermediates = sampler.sample( + ddim_steps, + num_samples, + shape, + cond, + verbose=False, + eta=1.0, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc_full, + x_T=start_code, + ) + x_samples_ddim = model.decode_first_stage(samples_cfg) + + result = torch.clamp((x_samples_ddim + 1.0) / 2.0, + min=0.0, max=1.0) + + result = result.cpu().numpy().transpose(0, 2, 3, 1) * 255 + return [put_watermark(Image.fromarray(img.astype(np.uint8)), wm_encoder) for img in result] + +def pad_image(input_image): + pad_w, pad_h = np.max(((2, 2), np.ceil( + np.array(input_image.size) / 64).astype(int)), axis=0) * 64 - input_image.size + im_padded = Image.fromarray( + np.pad(np.array(input_image), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) + return im_padded + +def crop_image(input_image): + crop_w, crop_h = np.floor(np.array(input_image.size) / 64).astype(int) * 64 + im_cropped = Image.fromarray(np.array(input_image)[:crop_h, :crop_w]) + return im_cropped + +# sampler = initialize_model(sys.argv[1], sys.argv[2]) +@torch.no_grad() +def predict(model, input_image, prompt, ddim_steps, num_samples, scale, seed): + """_summary_ + + Args: + input_image (_type_): dict + - image: PIL.Image. Input image. + - mask: PIL.Image. Mask image. + prompt (_type_): string to be used as prompt. + ddim_steps (_type_): typical 45 + num_samples (_type_): typical 4 + scale (_type_): typical 10.0 Guidance Scale. 
+ seed (_type_): typical 1529160519 + + """ + init_image = input_image["image"].convert("RGB") + init_mask = input_image["mask"].convert("RGB") + image = pad_image(init_image) # resize to integer multiple of 32 + mask = pad_image(init_mask) # resize to integer multiple of 32 + width, height = image.size + print("Inpainting...", width, height) + + result = inpaint( + sampler=model, + image=image, + mask=mask, + prompt=prompt, + seed=seed, + scale=scale, + ddim_steps=ddim_steps, + num_samples=num_samples, + h=height, w=width + ) + + return result \ No newline at end of file diff --git a/utils/misc.py b/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..7b7f187785f8f45ce3d0b069b94ff31150c707ac --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,122 @@ +import math +import numpy as np + +def get_prompt_templates(): + prompt_templates = [ + '{}.', + 'a photo of a {}.', + 'a bad photo of a {}.', + 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', + ] + return prompt_templates + + +def prompt_engineering(classnames, topk=1, suffix='.'): + prompt_templates = get_prompt_templates() + temp_idx = np.random.randint(min(len(prompt_templates), topk)) + + if isinstance(classnames, list): + classname = random.choice(classnames) + else: + classname = classnames + + return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' ')) + +class AverageMeter(object): + """Computes and stores the average and current value.""" + def 
__init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1, decay=0): + self.val = val + if decay: + alpha = math.exp(-n / decay) # exponential decay over 100 updates + self.sum = alpha * self.sum + (1 - alpha) * val * n + self.count = alpha * self.count + (1 - alpha) * n + else: + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/utils/model.py b/utils/model.py new file mode 100755 index 0000000000000000000000000000000000000000..c6002070f13c8ba45fa65da9ce907bcc88688a35 --- /dev/null +++ b/utils/model.py @@ -0,0 +1,32 @@ +import logging +import os +import time +import pickle + +import torch +import torch.distributed as dist + +from fvcore.nn import FlopCountAnalysis +from fvcore.nn import flop_count_table +from fvcore.nn import flop_count_str + +logger = logging.getLogger(__name__) + + +NORM_MODULES = [ + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + # NaiveSyncBatchNorm inherits from BatchNorm2d + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, +] + +def register_norm_module(cls): + NORM_MODULES.append(cls) + return cls \ No newline at end of file diff --git a/utils/model_loading.py b/utils/model_loading.py new file mode 100755 index 0000000000000000000000000000000000000000..e679cb7f59f19a3834110ace1f56a1bd077d0049 --- /dev/null +++ b/utils/model_loading.py @@ -0,0 +1,42 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import logging +from utils.distributed import is_main_process +logger = logging.getLogger(__name__) + + +def align_and_update_state_dicts(model_state_dict, ckpt_state_dict): + model_keys = sorted(model_state_dict.keys()) + ckpt_keys = sorted(ckpt_state_dict.keys()) + result_dicts = {} + matched_log = [] + unmatched_log = [] + unloaded_log = [] + for model_key in model_keys: + model_weight = model_state_dict[model_key] + if model_key in ckpt_keys: + ckpt_weight = ckpt_state_dict[model_key] + if model_weight.shape == ckpt_weight.shape: + result_dicts[model_key] = ckpt_weight + ckpt_keys.pop(ckpt_keys.index(model_key)) + matched_log.append("Loaded {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) + else: + unmatched_log.append("*UNMATCHED* {}, Model Shape: {} <-> Ckpt Shape: {}".format(model_key, model_weight.shape, ckpt_weight.shape)) + else: + unloaded_log.append("*UNLOADED* {}, Model Shape: {}".format(model_key, model_weight.shape)) + + if is_main_process(): + for info in matched_log: + logger.info(info) + for info in unloaded_log: + logger.warning(info) + for key in ckpt_keys: + logger.warning("$UNUSED$ {}, Ckpt Shape: {}".format(key, ckpt_state_dict[key].shape)) + for info in unmatched_log: + logger.warning(info) + return result_dicts \ No newline at end of file diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..868c090d4fca05263ee59b7f7e32ef04802674e0 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,283 @@ +# adopted from +# 
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! +import importlib + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, 
max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
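+    Typically wrapped around the last conv/linear layer of a block so that the block
+    initially contributes nothing and training starts from an identity-like mapping.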
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/utils/visualizer.py b/utils/visualizer.py new file mode 100755 index 0000000000000000000000000000000000000000..afdc2e2ff69f0b36b51c75c41d1893e8d9fb582e --- /dev/null +++ b/utils/visualizer.py @@ -0,0 +1,1278 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+import colorsys +import logging +import math +import numpy as np +from enum import Enum, unique +import cv2 +import matplotlib as mpl +import matplotlib.colors as mplc +import matplotlib.figure as mplfigure +import pycocotools.mask as mask_util +import torch +from matplotlib.backends.backend_agg import FigureCanvasAgg +from PIL import Image + +from detectron2.data import MetadataCatalog +from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes +from detectron2.utils.file_io import PathManager + +from detectron2.utils.colormap import random_color + +logger = logging.getLogger(__name__) +__all__ = ["ColorMode", "VisImage", "Visualizer"] + + +_SMALL_OBJECT_AREA_THRESH = 1000 +_LARGE_MASK_AREA_THRESH = 120000 +_OFF_WHITE = (1.0, 1.0, 240.0 / 255) +_BLACK = (0, 0, 0) +_RED = (1.0, 0, 0) + +_KEYPOINT_THRESHOLD = 0.05 + + +@unique +class ColorMode(Enum): + """ + Enum of different color modes to use for instance visualizations. + """ + + IMAGE = 0 + """ + Picks a random color for every instance and overlay segmentations with low opacity. + """ + SEGMENTATION = 1 + """ + Let instances of the same category have similar colors + (from metadata.thing_colors), and overlay them with + high opacity. This provides more attention on the quality of segmentation. + """ + IMAGE_BW = 2 + """ + Same as IMAGE, but convert all areas without masks to gray-scale. + Only available for drawing per-instance mask predictions. + """ + + +class GenericMask: + """ + Attribute: + polygons (list[ndarray]): list[ndarray]: polygons for this mask. + Each ndarray has format [x, y, x, y, ...] + mask (ndarray): a binary mask + """ + + def __init__(self, mask_or_polygons, height, width): + self._mask = self._polygons = self._has_holes = None + self.height = height + self.width = width + + m = mask_or_polygons + if isinstance(m, dict): + # RLEs + assert "counts" in m and "size" in m + if isinstance(m["counts"], list): # uncompressed RLEs + h, w = m["size"] + assert h == height and w == width + m = mask_util.frPyObjects(m, h, w) + self._mask = mask_util.decode(m)[:, :] + return + + if isinstance(m, list): # list[ndarray] + self._polygons = [np.asarray(x).reshape(-1) for x in m] + return + + if isinstance(m, np.ndarray): # assumed to be a binary mask + assert m.shape[1] != 2, m.shape + assert m.shape == ( + height, + width, + ), f"mask shape: {m.shape}, target dims: {height}, {width}" + self._mask = m.astype("uint8") + return + + raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m))) + + @property + def mask(self): + if self._mask is None: + self._mask = self.polygons_to_mask(self._polygons) + return self._mask + + @property + def polygons(self): + if self._polygons is None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + return self._polygons + + @property + def has_holes(self): + if self._has_holes is None: + if self._mask is not None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + else: + self._has_holes = False # if original format is polygon, does not have holes + return self._has_holes + + def mask_to_polygons(self, mask): + # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level + # hierarchy. External contours (boundary) of the object are placed in hierarchy-1. + # Internal contours (holes) are placed in hierarchy-2. + # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours. 
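+        # Indexing the result as res[-2] (contours) and res[-1] (hierarchy) below keeps
+        # this compatible with both OpenCV 3.x, which returns (image, contours, hierarchy),
+        # and OpenCV 4.x, which returns (contours, hierarchy).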
+ mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr + res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + hierarchy = res[-1] + if hierarchy is None: # empty mask + return [], False + has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0 + res = res[-2] + res = [x.flatten() for x in res] + # These coordinates from OpenCV are integers in range [0, W-1 or H-1]. + # We add 0.5 to turn them into real-value coordinate space. A better solution + # would be to first +0.5 and then dilate the returned polygon by 0.5. + res = [x + 0.5 for x in res if len(x) >= 6] + return res, has_holes + + def polygons_to_mask(self, polygons): + rle = mask_util.frPyObjects(polygons, self.height, self.width) + rle = mask_util.merge(rle) + return mask_util.decode(rle)[:, :] + + def area(self): + return self.mask.sum() + + def bbox(self): + p = mask_util.frPyObjects(self.polygons, self.height, self.width) + p = mask_util.merge(p) + bbox = mask_util.toBbox(p) + bbox[2] += bbox[0] + bbox[3] += bbox[1] + return bbox + + +class _PanopticPrediction: + """ + Unify different panoptic annotation/prediction formats + """ + + def __init__(self, panoptic_seg, segments_info, metadata=None): + if segments_info is None: + assert metadata is not None + # If "segments_info" is None, we assume "panoptic_img" is a + # H*W int32 image storing the panoptic_id in the format of + # category_id * label_divisor + instance_id. We reserve -1 for + # VOID label. + label_divisor = metadata.label_divisor + segments_info = [] + for panoptic_label in np.unique(panoptic_seg.numpy()): + if panoptic_label == -1: + # VOID region. + continue + pred_class = panoptic_label // label_divisor + isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values() + segments_info.append( + { + "id": int(panoptic_label), + "category_id": int(pred_class), + "isthing": bool(isthing), + } + ) + del metadata + + self._seg = panoptic_seg + + self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info + segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True) + areas = areas.numpy() + sorted_idxs = np.argsort(-areas) + self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs] + self._seg_ids = self._seg_ids.tolist() + for sid, area in zip(self._seg_ids, self._seg_areas): + if sid in self._sinfo: + self._sinfo[sid]["area"] = float(area) + + def non_empty_mask(self): + """ + Returns: + (H, W) array, a mask for all pixels that have a prediction + """ + empty_ids = [] + for id in self._seg_ids: + if id not in self._sinfo: + empty_ids.append(id) + if len(empty_ids) == 0: + return np.zeros(self._seg.shape, dtype=np.uint8) + assert ( + len(empty_ids) == 1 + ), ">1 ids corresponds to no labels. This is currently not supported" + return (self._seg != empty_ids[0]).numpy().astype(np.bool) + + def semantic_masks(self): + for sid in self._seg_ids: + sinfo = self._sinfo.get(sid) + if sinfo is None or sinfo["isthing"]: + # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions. 
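+                # (their ids never make it into self._sinfo), and "thing" segments are
+                # skipped here because instance_masks() below yields them instead.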
+ continue + yield (self._seg == sid).numpy().astype(np.bool), sinfo + + def instance_masks(self): + for sid in self._seg_ids: + sinfo = self._sinfo.get(sid) + if sinfo is None or not sinfo["isthing"]: + continue + mask = (self._seg == sid).numpy().astype(np.bool) + if mask.sum() > 0: + yield mask, sinfo + + +def _create_text_labels(classes, scores, class_names, is_crowd=None): + """ + Args: + classes (list[int] or None): + scores (list[float] or None): + class_names (list[str] or None): + is_crowd (list[bool] or None): + + Returns: + list[str] or None + """ + labels = None + if classes is not None: + if class_names is not None and len(class_names) > 0: + labels = [class_names[i] for i in classes] + else: + labels = [str(i) for i in classes] + if scores is not None: + if labels is None: + labels = ["{:.0f}%".format(s * 100) for s in scores] + else: + labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)] + if labels is not None and is_crowd is not None: + labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)] + return labels + + +class VisImage: + def __init__(self, img, scale=1.0): + """ + Args: + img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255]. + scale (float): scale the input image + """ + self.img = img + self.scale = scale + self.width, self.height = img.shape[1], img.shape[0] + self._setup_figure(img) + + def _setup_figure(self, img): + """ + Args: + Same as in :meth:`__init__()`. + + Returns: + fig (matplotlib.pyplot.figure): top level container for all the image plot elements. + ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system. + """ + fig = mplfigure.Figure(frameon=False) + self.dpi = fig.get_dpi() + # add a small 1e-2 to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches( + (self.width * self.scale + 1e-2) / self.dpi, + (self.height * self.scale + 1e-2) / self.dpi, + ) + self.canvas = FigureCanvasAgg(fig) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + self.fig = fig + self.ax = ax + self.reset_image(img) + + def reset_image(self, img): + """ + Args: + img: same as in __init__ + """ + img = img.astype("uint8") + self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest") + + def save(self, filepath): + """ + Args: + filepath (str): a string that contains the absolute path, including the file name, where + the visualized image will be saved. + """ + self.fig.savefig(filepath) + + def get_image(self): + """ + Returns: + ndarray: + the visualized image of shape (H, W, 3) (RGB) in uint8 type. + The shape is scaled w.r.t the input image using the given `scale` argument. + """ + canvas = self.canvas + s, (width, height) = canvas.print_to_buffer() + # buf = io.BytesIO() # works for cairo backend + # canvas.print_rgba(buf) + # width, height = self.width, self.height + # s = buf.getvalue() + + buffer = np.frombuffer(s, dtype="uint8") + + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype("uint8") + + +class Visualizer: + """ + Visualizer that draws data about detection/segmentation on images. 
+ + It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` + that draw primitive objects to images, as well as high-level wrappers like + `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` + that draw composite data in some pre-defined style. + + Note that the exact visualization style for the high-level wrappers are subject to change. + Style such as color, opacity, label contents, visibility of labels, or even the visibility + of objects themselves (e.g. when the object is too small) may change according + to different heuristics, as long as the results still look visually reasonable. + + To obtain a consistent style, you can implement custom drawing functions with the + abovementioned primitive methods instead. If you need more customized visualization + styles, you can process the data yourself following their format documented in + tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not + intend to satisfy everyone's preference on drawing styles. + + This visualizer focuses on high rendering quality rather than performance. It is not + designed to be used for real-time applications. + """ + + # TODO implement a fast, rasterized version using OpenCV + + def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE): + """ + Args: + img_rgb: a numpy array of shape (H, W, C), where H and W correspond to + the height and width of the image respectively. C is the number of + color channels. The image is required to be in RGB format since that + is a requirement of the Matplotlib library. The image is also expected + to be in the range [0, 255]. + metadata (Metadata): dataset metadata (e.g. class names and colors) + instance_mode (ColorMode): defines one of the pre-defined style for drawing + instances on an image. + """ + self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) + if metadata is None: + metadata = MetadataCatalog.get("__nonexist__") + self.metadata = metadata + self.output = VisImage(self.img, scale=scale) + self.cpu_device = torch.device("cpu") + + # too small texts are useless, therefore clamp to 9 + self._default_font_size = max( + np.sqrt(self.output.height * self.output.width) // 90, 10 // scale + ) + self._default_font_size = 18 + self._instance_mode = instance_mode + self.keypoint_threshold = _KEYPOINT_THRESHOLD + + def draw_instance_predictions(self, predictions): + """ + Draw instance-level prediction results on an image. + + Args: + predictions (Instances): the output of an instance detection/segmentation + model. Following fields will be used to draw: + "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). + + Returns: + output (VisImage): image object with visualizations. 
+ """ + boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None + labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) + keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None + + keep = (scores > 0.8).cpu() + boxes = boxes[keep] + scores = scores[keep] + classes = np.array(classes) + classes = classes[np.array(keep)] + labels = np.array(labels) + labels = labels[np.array(keep)] + + if predictions.has("pred_masks"): + masks = np.asarray(predictions.pred_masks) + masks = masks[np.array(keep)] + masks = [GenericMask(x, self.output.height, self.output.width) for x in masks] + else: + masks = None + + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + # if self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes + ] + alpha = 0.4 + else: + colors = None + alpha = 0.4 + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image( + self._create_grayscale_image( + (predictions.pred_masks.any(dim=0) > 0).numpy() + if predictions.has("pred_masks") + else None + ) + ) + alpha = 0.3 + + self.overlay_instances( + masks=masks, + boxes=boxes, + labels=labels, + keypoints=keypoints, + assigned_colors=colors, + alpha=alpha, + ) + return self.output + + def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7): + """ + Draw semantic segmentation predictions/labels. + + Args: + sem_seg (Tensor or ndarray): the segmentation of shape (H, W). + Each value is the integer label of the pixel. + area_threshold (int): segments with less than `area_threshold` are not drawn. + alpha (float): the larger it is, the more opaque the segmentations are. + + Returns: + output (VisImage): image object with visualizations. + """ + if isinstance(sem_seg, torch.Tensor): + sem_seg = sem_seg.numpy() + labels, areas = np.unique(sem_seg, return_counts=True) + sorted_idxs = np.argsort(-areas).tolist() + labels = labels[sorted_idxs] + for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels): + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] + except (AttributeError, IndexError): + mask_color = None + + binary_mask = (sem_seg == label).astype(np.uint8) + text = self.metadata.stuff_classes[label] + self.draw_binary_mask( + binary_mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + return self.output + + def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7): + """ + Draw panoptic prediction annotations or results. + + Args: + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each + segment. + segments_info (list[dict] or None): Describe each segment in `panoptic_seg`. + If it is a ``list[dict]``, each dict contains keys "id", "category_id". + If None, category id of each pixel is computed by + ``pixel // metadata.label_divisor``. + area_threshold (int): stuff segments with less than `area_threshold` are not drawn. + + Returns: + output (VisImage): image object with visualizations. 
+ """ + pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata) + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask())) + + # draw mask for all semantic segments first i.e. "stuff" + for mask, sinfo in pred.semantic_masks(): + category_idx = sinfo["category_id"] + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]] + except AttributeError: + mask_color = None + + text = self.metadata.stuff_classes[category_idx] + self.draw_binary_mask( + mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + + # draw mask for all instances second + all_instances = list(pred.instance_masks()) + if len(all_instances) == 0: + return self.output + masks, sinfo = list(zip(*all_instances)) + category_ids = [x["category_id"] for x in sinfo] + + try: + scores = [x["score"] for x in sinfo] + except KeyError: + scores = None + labels = _create_text_labels( + category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo] + ) + + try: + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids + ] + except AttributeError: + colors = None + self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha) + + return self.output + + draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility + + def draw_dataset_dict(self, dic): + """ + Draw annotations/segmentaions in Detectron2 Dataset format. + + Args: + dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format. + + Returns: + output (VisImage): image object with visualizations. + """ + annos = dic.get("annotations", None) + if annos: + if "segmentation" in annos[0]: + masks = [x["segmentation"] for x in annos] + else: + masks = None + if "keypoints" in annos[0]: + keypts = [x["keypoints"] for x in annos] + keypts = np.array(keypts).reshape(len(annos), -1, 3) + else: + keypts = None + + boxes = [ + BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS) + if len(x["bbox"]) == 4 + else x["bbox"] + for x in annos + ] + + colors = None + category_ids = [x["category_id"] for x in annos] + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in category_ids + ] + names = self.metadata.get("thing_classes", None) + labels = _create_text_labels( + category_ids, + scores=None, + class_names=names, + is_crowd=[x.get("iscrowd", 0) for x in annos], + ) + self.overlay_instances( + labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors + ) + + sem_seg = dic.get("sem_seg", None) + if sem_seg is None and "sem_seg_file_name" in dic: + with PathManager.open(dic["sem_seg_file_name"], "rb") as f: + sem_seg = Image.open(f) + sem_seg = np.asarray(sem_seg, dtype="uint8") + if sem_seg is not None: + self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4) + + pan_seg = dic.get("pan_seg", None) + if pan_seg is None and "pan_seg_file_name" in dic: + with PathManager.open(dic["pan_seg_file_name"], "rb") as f: + pan_seg = Image.open(f) + pan_seg = np.asarray(pan_seg) + from panopticapi.utils import rgb2id + + pan_seg = rgb2id(pan_seg) + if pan_seg is not None: + segments_info = dic["segments_info"] + pan_seg = torch.tensor(pan_seg) + self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7) + return self.output + + def 
overlay_instances( + self, + *, + boxes=None, + labels=None, + masks=None, + keypoints=None, + assigned_colors=None, + alpha=0.5, + ): + """ + Args: + boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`, + or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, + or a :class:`RotatedBoxes`, + or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image, + labels (list[str]): the text to be displayed for each instance. + masks (masks-like object): Supported types are: + + * :class:`detectron2.structures.PolygonMasks`, + :class:`detectron2.structures.BitMasks`. + * list[list[ndarray]]: contains the segmentation masks for all objects in one image. + The first level of the list corresponds to individual instances. The second + level to all the polygon that compose the instance, and the third level + to the polygon coordinates. The third level should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + * list[ndarray]: each ndarray is a binary mask of shape (H, W). + * list[dict]: each dict is a COCO-style RLE. + keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), + where the N is the number of instances and K is the number of keypoints. + The last dimension corresponds to (x, y, visibility or score). + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = 0 + if boxes is not None: + boxes = self._convert_boxes(boxes) + num_instances = len(boxes) + if masks is not None: + masks = self._convert_masks(masks) + if num_instances: + assert len(masks) == num_instances + else: + num_instances = len(masks) + if keypoints is not None: + if num_instances: + assert len(keypoints) == num_instances + else: + num_instances = len(keypoints) + keypoints = self._convert_keypoints(keypoints) + if labels is not None: + assert len(labels) == num_instances + if assigned_colors is None: + assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] + if num_instances == 0: + return self.output + if boxes is not None and boxes.shape[1] == 5: + return self.overlay_rotated_instances( + boxes=boxes, labels=labels, assigned_colors=assigned_colors + ) + + # Display in largest to smallest order to reduce occlusion. + areas = None + if boxes is not None: + areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) + elif masks is not None: + areas = np.asarray([x.area() for x in masks]) + + if areas is not None: + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. 
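+            # (Large instances are drawn first so that smaller, overlapping ones are
+            # painted on top of them and remain visible.)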
+ boxes = boxes[sorted_idxs] if boxes is not None else None + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None + assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] + keypoints = keypoints[sorted_idxs] if keypoints is not None else None + + for i in range(num_instances): + color = assigned_colors[i] + if boxes is not None: + self.draw_box(boxes[i], edge_color=color) + + if masks is not None: + for segment in masks[i].polygons: + self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha) + + if labels is not None: + # first get a box + if boxes is not None: + x0, y0, x1, y1 = boxes[i] + text_pos = (x0, y0) # if drawing boxes, put text on the box corner. + horiz_align = "left" + elif masks is not None: + # skip small mask without polygon + if len(masks[i].polygons) == 0: + continue + + x0, y0, x1, y1 = masks[i].bbox() + + # draw text in the center (defined by median) when box is not drawn + # median is less sensitive to outliers. + text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1] + horiz_align = "center" + else: + continue # drawing the box confidence for keypoints isn't very useful. + # for small objects, draw text at the side to avoid occlusion + instance_area = (y1 - y0) * (x1 - x0) + if ( + instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale + or y1 - y0 < 40 * self.output.scale + ): + if y1 >= self.output.height - 5: + text_pos = (x1, y0) + else: + text_pos = (x0, y1) + + height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width) + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + * 0.5 + * self._default_font_size + ) + self.draw_text( + labels[i], + text_pos, + color=lighter_color, + horizontal_alignment=horiz_align, + font_size=font_size, + ) + + # draw keypoints + if keypoints is not None: + for keypoints_per_instance in keypoints: + self.draw_and_connect_keypoints(keypoints_per_instance) + + return self.output + + def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None): + """ + Args: + boxes (ndarray): an Nx5 numpy array of + (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image. + labels (list[str]): the text to be displayed for each instance. + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = len(boxes) + + if assigned_colors is None: + assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] + if num_instances == 0: + return self.output + + # Display in largest to smallest order to reduce occlusion. + if boxes is not None: + areas = boxes[:, 2] * boxes[:, 3] + + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. 
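+            # (Same largest-first ordering as in ``overlay_instances``, using w * h of
+            # each rotated box as its area.)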
+ boxes = boxes[sorted_idxs] + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + colors = [assigned_colors[idx] for idx in sorted_idxs] + + for i in range(num_instances): + self.draw_rotated_box_with_label( + boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None + ) + + return self.output + + def draw_and_connect_keypoints(self, keypoints): + """ + Draws keypoints of an instance and follows the rules for keypoint connections + to draw lines between appropriate keypoints. This follows color heuristics for + line color. + + Args: + keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints + and the last dimension corresponds to (x, y, probability). + + Returns: + output (VisImage): image object with visualizations. + """ + visible = {} + keypoint_names = self.metadata.get("keypoint_names") + for idx, keypoint in enumerate(keypoints): + + # draw keypoint + x, y, prob = keypoint + if prob > self.keypoint_threshold: + self.draw_circle((x, y), color=_RED) + if keypoint_names: + keypoint_name = keypoint_names[idx] + visible[keypoint_name] = (x, y) + + if self.metadata.get("keypoint_connection_rules"): + for kp0, kp1, color in self.metadata.keypoint_connection_rules: + if kp0 in visible and kp1 in visible: + x0, y0 = visible[kp0] + x1, y1 = visible[kp1] + color = tuple(x / 255.0 for x in color) + self.draw_line([x0, x1], [y0, y1], color=color) + + # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip + # Note that this strategy is specific to person keypoints. + # For other keypoints, it should just do nothing + try: + ls_x, ls_y = visible["left_shoulder"] + rs_x, rs_y = visible["right_shoulder"] + mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2 + except KeyError: + pass + else: + # draw line from nose to mid-shoulder + nose_x, nose_y = visible.get("nose", (None, None)) + if nose_x is not None: + self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED) + + try: + # draw line from mid-shoulder to mid-hip + lh_x, lh_y = visible["left_hip"] + rh_x, rh_y = visible["right_hip"] + except KeyError: + pass + else: + mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2 + self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED) + return self.output + + """ + Primitive drawing functions: + """ + + def draw_text( + self, + text, + position, + *, + font_size=None, + color="g", + horizontal_alignment="center", + rotation=0, + ): + """ + Args: + text (str): class label + position (tuple): a tuple of the x and y coordinates to place text on image. + font_size (int, optional): font of the text. If not provided, a font size + proportional to the image width is calculated and used. + color: color of the text. Refer to `matplotlib.colors` for full list + of formats that are accepted. + horizontal_alignment (str): see `matplotlib.text.Text` + rotation: rotation angle in degrees CCW + + Returns: + output (VisImage): image object with text drawn. 
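+
+        Example (an illustrative sketch; ``visualizer`` is assumed to be an existing
+        ``Visualizer`` instance)::
+
+            visualizer.draw_text("person 97%", (50, 60), color="w", font_size=12)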
+ """ + if not font_size: + font_size = self._default_font_size + + # since the text background is dark, we don't want the text to be dark + color = np.maximum(list(mplc.to_rgb(color)), 0.2) + color[np.argmax(color)] = max(0.8, np.max(color)) + + x, y = position + self.output.ax.text( + x, + y, + text, + size=font_size * self.output.scale, + family="sans-serif", + bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + zorder=10, + rotation=rotation, + ) + return self.output + + def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"): + """ + Args: + box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0 + are the coordinates of the image's top left corner. x1 and y1 are the + coordinates of the image's bottom right corner. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + + Returns: + output (VisImage): image object with box drawn. + """ + x0, y0, x1, y1 = box_coord + width = x1 - x0 + height = y1 - y0 + + linewidth = max(self._default_font_size / 4, 1) + + self.output.ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=edge_color, + linewidth=linewidth * self.output.scale, + alpha=alpha, + linestyle=line_style, + ) + ) + return self.output + + def draw_rotated_box_with_label( + self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None + ): + """ + Draw a rotated box with label on its top-left corner. + + Args: + rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle), + where cnt_x and cnt_y are the center coordinates of the box. + w and h are the width and height of the box. angle represents how + many degrees the box is rotated CCW with regard to the 0-degree box. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + label (string): label for rotated box. It will not be rendered when set to None. + + Returns: + output (VisImage): image object with box drawn. 
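+
+        Example (an illustrative sketch: an 80x40 box centered at (200, 150), rotated
+        30 degrees CCW; ``visualizer`` is an assumed existing instance)::
+
+            visualizer.draw_rotated_box_with_label(
+                (200, 150, 80, 40, 30), edge_color="r", label="car"
+            )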
+ """ + cnt_x, cnt_y, w, h, angle = rotated_box + area = w * h + # use thinner lines when the box is small + linewidth = self._default_font_size / ( + 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3 + ) + + theta = angle * math.pi / 180.0 + c = math.cos(theta) + s = math.sin(theta) + rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)] + # x: left->right ; y: top->down + rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect] + for k in range(4): + j = (k + 1) % 4 + self.draw_line( + [rotated_rect[k][0], rotated_rect[j][0]], + [rotated_rect[k][1], rotated_rect[j][1]], + color=edge_color, + linestyle="--" if k == 1 else line_style, + linewidth=linewidth, + ) + + if label is not None: + text_pos = rotated_rect[1] # topleft corner + + height_ratio = h / np.sqrt(self.output.height * self.output.width) + label_color = self._change_color_brightness(edge_color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size + ) + self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle) + + return self.output + + def draw_circle(self, circle_coord, color, radius=3): + """ + Args: + circle_coord (list(int) or tuple(int)): contains the x and y coordinates + of the center of the circle. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + radius (int): radius of the circle. + + Returns: + output (VisImage): image object with box drawn. + """ + x, y = circle_coord + self.output.ax.add_patch( + mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color) + ) + return self.output + + def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None): + """ + Args: + x_data (list[int]): a list containing x values of all the points being drawn. + Length of list should match the length of y_data. + y_data (list[int]): a list containing y values of all the points being drawn. + Length of list should match the length of x_data. + color: color of the line. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + linestyle: style of the line. Refer to `matplotlib.lines.Line2D` + for a full list of formats that are accepted. + linewidth (float or None): width of the line. When it's None, + a default value will be computed and used. + + Returns: + output (VisImage): image object with line drawn. + """ + if linewidth is None: + linewidth = self._default_font_size / 3 + linewidth = max(linewidth, 1) + self.output.ax.add_line( + mpl.lines.Line2D( + x_data, + y_data, + linewidth=linewidth * self.output.scale, + color=color, + linestyle=linestyle, + ) + ) + return self.output + + def draw_binary_mask( + self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.7, area_threshold=10 + ): + """ + Args: + binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and + W is the image width. Each value in the array is either a 0 or 1 value of uint8 + type. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + area_threshold (float): a connected component smaller than this area will not be shown. 
+ + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + has_valid_segment = False + binary_mask = binary_mask.astype("uint8") # opencv needs uint8 + mask = GenericMask(binary_mask, self.output.height, self.output.width) + shape2d = (binary_mask.shape[0], binary_mask.shape[1]) + + if not mask.has_holes: + # draw polygons for regular masks + for segment in mask.polygons: + area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1])) + if area < (area_threshold or 0): + continue + has_valid_segment = True + segment = segment.reshape(-1, 2) + self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha) + else: + # TODO: Use Path/PathPatch to draw vector graphics: + # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha + has_valid_segment = True + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if text is not None and has_valid_segment: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5): + """ + Args: + soft_mask (ndarray): float array of shape (H, W), each value in [0, 1]. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + shape2d = (soft_mask.shape[0], soft_mask.shape[1]) + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = soft_mask * alpha + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if text is not None: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + binary_mask = (soft_mask > 0.5).astype("uint8") + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_polygon(self, segment, color, edge_color=None, alpha=0.5): + """ + Args: + segment: numpy array of shape Nx2, containing all the points in the polygon. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. If not provided, a darker shade + of the polygon color will be used instead. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with polygon drawn. 
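+
+        Example (an illustrative sketch: a semi-transparent blue triangle;
+        ``visualizer`` is an assumed existing instance)::
+
+            visualizer.draw_polygon(
+                np.array([[10, 10], [120, 20], [60, 90]]), color="b", alpha=0.4
+            )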
+ """ + if edge_color is None: + # make edge color darker than the polygon color + if alpha > 0.8: + edge_color = self._change_color_brightness(color, brightness_factor=-0.7) + else: + edge_color = color + edge_color = mplc.to_rgb(edge_color) + (1,) + + polygon = mpl.patches.Polygon( + segment, + fill=True, + facecolor=mplc.to_rgb(color) + (alpha,), + edgecolor=edge_color, + linewidth=max(self._default_font_size // 15 * self.output.scale, 1), + ) + self.output.ax.add_patch(polygon) + return self.output + + """ + Internal methods: + """ + + def _jitter(self, color): + """ + Randomly modifies given color to produce a slightly different color than the color given. + + Args: + color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color + picked. The values in the list are in the [0.0, 1.0] range. + + Returns: + jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the + color after being jittered. The values in the list are in the [0.0, 1.0] range. + """ + color = mplc.to_rgb(color) + # np.random.seed(0) + vec = np.random.rand(3) + # better to do it in another color space + vec = vec / np.linalg.norm(vec) * 0.5 + res = np.clip(vec + color, 0, 1) + return tuple(res) + + def _create_grayscale_image(self, mask=None): + """ + Create a grayscale version of the original image. + The colors in masked area, if given, will be kept. + """ + img_bw = self.img.astype("f4").mean(axis=2) + img_bw = np.stack([img_bw] * 3, axis=2) + if mask is not None: + img_bw[mask] = self.img[mask] + return img_bw + + def _change_color_brightness(self, color, brightness_factor): + """ + Depending on the brightness_factor, gives a lighter or darker color i.e. a color with + less or more saturation than the original color. + + Args: + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of + 0 will correspond to no change, a factor in [-1.0, 0) range will result in + a darker color and a factor in (0, 1.0] range will result in a lighter color. + + Returns: + modified_color (tuple[double]): a tuple containing the RGB values of the + modified color. Each value in the tuple is in the [0.0, 1.0] range. + """ + assert brightness_factor >= -1.0 and brightness_factor <= 1.0 + color = mplc.to_rgb(color) + polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) + modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) + modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness + modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness + modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) + return modified_color + + def _convert_boxes(self, boxes): + """ + Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension. + """ + if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes): + return boxes.tensor.detach().numpy() + else: + return np.asarray(boxes) + + def _convert_masks(self, masks_or_polygons): + """ + Convert different format of masks or polygons to a tuple of masks and polygons. 
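+        Accepted inputs (as read off the conversion logic below): :class:`PolygonMasks`,
+        :class:`BitMasks`, a ``torch.Tensor`` of binary masks, or a list whose elements
+        are polygons, binary ndarrays, COCO-style RLE dicts, or :class:`GenericMask`.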
+ + Returns: + list[GenericMask]: + """ + + m = masks_or_polygons + if isinstance(m, PolygonMasks): + m = m.polygons + if isinstance(m, BitMasks): + m = m.tensor.numpy() + if isinstance(m, torch.Tensor): + m = m.numpy() + ret = [] + for x in m: + if isinstance(x, GenericMask): + ret.append(x) + else: + ret.append(GenericMask(x, self.output.height, self.output.width)) + return ret + + def _draw_text_in_mask(self, binary_mask, text, color): + """ + Find proper places to draw text given a binary mask. + """ + # TODO sometimes drawn on wrong objects. the heuristics here can improve. + _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8) + if stats[1:, -1].size == 0: + return + largest_component_id = np.argmax(stats[1:, -1]) + 1 + + # draw text on the largest component, as well as other very large components. + for cid in range(1, _num_cc): + if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: + # median is more stable than centroid + # center = centroids[largest_component_id] + center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] + self.draw_text(text, center, color=color) + + def _convert_keypoints(self, keypoints): + if isinstance(keypoints, Keypoints): + keypoints = keypoints.tensor + keypoints = np.asarray(keypoints) + return keypoints + + def get_output(self): + """ + Returns: + output (VisImage): the image output containing the visualizations added + to the image. + """ + return self.output \ No newline at end of file diff --git a/v_emb.da b/v_emb.da new file mode 100644 index 0000000000000000000000000000000000000000..6f2af1eb201a916fd973472c1b1cf7ba49f3a6d3 Binary files /dev/null and b/v_emb.da differ diff --git a/xdecoder/BaseModel.py b/xdecoder/BaseModel.py new file mode 100755 index 0000000000000000000000000000000000000000..cd0803f43d53554db6e718302ef28aa573bc05a5 --- /dev/null +++ b/xdecoder/BaseModel.py @@ -0,0 +1,37 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import os +import logging + +import torch +import torch.nn as nn + +from utils.model_loading import align_and_update_state_dicts + +logger = logging.getLogger(__name__) + + +class BaseModel(nn.Module): + def __init__(self, opt, module: nn.Module): + super(BaseModel, self).__init__() + self.opt = opt + self.model = module + + def forward(self, *inputs, **kwargs): + outputs = self.model(*inputs, **kwargs) + return outputs + + def save_pretrained(self, save_dir): + save_path = os.path.join(save_dir, 'model_state_dict.pt') + torch.save(self.model.state_dict(), save_path) + + def from_pretrained(self, load_path): + state_dict = torch.load(load_path, map_location=self.opt['device']) + state_dict = align_and_update_state_dicts(self.model.state_dict(), state_dict) + self.model.load_state_dict(state_dict, strict=False) + return self \ No newline at end of file diff --git a/xdecoder/__init__.py b/xdecoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..47a369b587a83ac2691a90e583a4bbb5c0cb23e0 --- /dev/null +++ b/xdecoder/__init__.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from .architectures import build_model \ No newline at end of file diff --git 
a/xdecoder/__pycache__/BaseModel.cpython-38.pyc b/xdecoder/__pycache__/BaseModel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0c87b95fe7c0f4f285509fbc5e0516136b9ef6e Binary files /dev/null and b/xdecoder/__pycache__/BaseModel.cpython-38.pyc differ diff --git a/xdecoder/__pycache__/__init__.cpython-38.pyc b/xdecoder/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6a928826f56e48b972a1cab1e315bee5d82b02f Binary files /dev/null and b/xdecoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/architectures/__init__.py b/xdecoder/architectures/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..7831efa29c9427175212c79734aa06b88651c53f --- /dev/null +++ b/xdecoder/architectures/__init__.py @@ -0,0 +1,2 @@ +from .xdecoder_model import * +from .build import build_model \ No newline at end of file diff --git a/xdecoder/architectures/__pycache__/__init__.cpython-38.pyc b/xdecoder/architectures/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3d45ff06a43f2e4a6f53c36a3425cb391639524 Binary files /dev/null and b/xdecoder/architectures/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/architectures/__pycache__/build.cpython-38.pyc b/xdecoder/architectures/__pycache__/build.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca54f4ebc5baf18c82aa2537bcf5ebe7a232966f Binary files /dev/null and b/xdecoder/architectures/__pycache__/build.cpython-38.pyc differ diff --git a/xdecoder/architectures/__pycache__/registry.cpython-38.pyc b/xdecoder/architectures/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31461626dc3a6d0bc30d1e50f2db44ca41160356 Binary files /dev/null and b/xdecoder/architectures/__pycache__/registry.cpython-38.pyc differ diff --git a/xdecoder/architectures/__pycache__/xdecoder_model.cpython-38.pyc b/xdecoder/architectures/__pycache__/xdecoder_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..392660f44d7044b6da4cf5f8b268d6b184a4209a Binary files /dev/null and b/xdecoder/architectures/__pycache__/xdecoder_model.cpython-38.pyc differ diff --git a/xdecoder/architectures/build.py b/xdecoder/architectures/build.py new file mode 100755 index 0000000000000000000000000000000000000000..c94201fe7ec172040ac092b7efe7d0a7b0adbd47 --- /dev/null +++ b/xdecoder/architectures/build.py @@ -0,0 +1,10 @@ +from .registry import model_entrypoints +from .registry import is_model + +def build_model(config, **kwargs): + model_name = config['MODEL']['NAME'] + + if not is_model(model_name): + raise ValueError(f'Unkown model: {model_name}') + + return model_entrypoints(model_name)(config, **kwargs) \ No newline at end of file diff --git a/xdecoder/architectures/registry.py b/xdecoder/architectures/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..940e4560f7d052aed4915187410266ab5a4cb4d0 --- /dev/null +++ b/xdecoder/architectures/registry.py @@ -0,0 +1,13 @@ +_model_entrypoints = {} + +def register_model(fn): + module_name_split = fn.__module__.split('.') + model_name = module_name_split[-1] + _model_entrypoints[model_name] = fn + return fn + +def model_entrypoints(model_name): + return _model_entrypoints[model_name] + +def is_model(model_name): + return model_name in _model_entrypoints \ No newline at end of file diff --git 
a/xdecoder/architectures/xdecoder_model.py b/xdecoder/architectures/xdecoder_model.py new file mode 100755 index 0000000000000000000000000000000000000000..65ee51e84247861a4cd6690248e893d1d9c15ad3 --- /dev/null +++ b/xdecoder/architectures/xdecoder_model.py @@ -0,0 +1,622 @@ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +import random +from typing import Tuple +from unicodedata import name + +import torch +from torch import nn +from torch.nn import functional as F +import numpy as np + +from .registry import register_model +from ..utils import configurable +from ..backbone import build_backbone, Backbone +from ..body import build_xdecoder_head +from ..modules import sem_seg_postprocess, bbox_postprocess +from ..language import build_language_encoder +from ..language.loss import vl_similarity + +from timm.models.layers import trunc_normal_ +from nltk.stem.lancaster import LancasterStemmer +from detectron2.structures import Boxes, ImageList, Instances, BitMasks, BoxMode +from detectron2.utils.memory import retry_if_cuda_oom +from detectron2.data import MetadataCatalog +from utils.misc import prompt_engineering + +st = LancasterStemmer() + + +class X_Decoder_Model(nn.Module): + @configurable + def __init__( + self, + *, + backbone: Backbone, + sem_seg_head: nn.Module, + criterion: nn.Module, + losses: dict, + num_queries: int, + object_mask_threshold: float, + overlap_threshold: float, + metadata, + task_switch: dict, + phrase_prob: float, + size_divisibility: int, + sem_seg_postprocess_before_inference: bool, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + # inference + semantic_on: bool, + panoptic_on: bool, + instance_on: bool, + test_topk_per_image: int, + train_dataset_name: str, + retrieval_emsemble: bool, + backbone_dim: int, + dim_proj: int, + ): + super().__init__() + self.backbone = backbone + self.sem_seg_head = sem_seg_head + self.criterion = criterion + self.losses = losses + self.num_queries = num_queries + self.overlap_threshold = overlap_threshold + self.object_mask_threshold = object_mask_threshold + self.metadata = metadata + if size_divisibility < 0: + # use backbone size_divisibility if not set + size_divisibility = self.backbone.size_divisibility + self.size_divisibility = size_divisibility + self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + # additional args + self.semantic_on = semantic_on + self.instance_on = instance_on + self.panoptic_on = panoptic_on + + # caption argument + self.task_switch = task_switch + self.phrase_prob = phrase_prob + + self.test_topk_per_image = test_topk_per_image + self.train_class_names = None + + self.retrieval_emsemble = retrieval_emsemble + # backbone itc loss + if task_switch['retrieval'] and retrieval_emsemble: + self.backbone_proj = nn.Parameter(torch.empty(backbone_dim, dim_proj)) + trunc_normal_(self.backbone_proj, std=.02) + + if not self.semantic_on: + assert self.sem_seg_postprocess_before_inference + + @classmethod + def from_config(cls, cfg): + enc_cfg = cfg['MODEL']['ENCODER'] + dec_cfg = cfg['MODEL']['DECODER'] + + task_switch = {'bbox': 
dec_cfg.get('DETECTION', False),
+                       'mask': dec_cfg.get('MASK', True),
+                       'caption': dec_cfg['CAPTION'].get('ENABLED', False),
+                       'captioning': dec_cfg['CAPTIONING'].get('ENABLED', False),
+                       'retrieval': dec_cfg['RETRIEVAL'].get('ENABLED', False),
+                       'grounding': dec_cfg['GROUNDING'].get('ENABLED', False)}
+
+        # build model
+        extra = {'task_switch': task_switch}
+        backbone = build_backbone(cfg)
+        lang_encoder = build_language_encoder(cfg)
+        sem_seg_head = build_xdecoder_head(cfg, backbone.output_shape(), lang_encoder, extra)
+
+        # Training Settings.
+        loss_weights = {}
+        matcher = None
+        losses = {}
+        weight_dict = {}
+        grd_weight = {}
+        top_x_layers = {}
+        criterion = None
+        train_dataset_name = None
+        phrase_prob = None
+        # Loss parameters:
+        deep_supervision = None
+        no_object_weight = None
+
+        return {
+            "backbone": backbone,
+            "sem_seg_head": sem_seg_head,
+            "criterion": criterion,
+            "losses": losses,
+            "num_queries": dec_cfg['NUM_OBJECT_QUERIES'],
+            "object_mask_threshold": dec_cfg['TEST']['OBJECT_MASK_THRESHOLD'],
+            "overlap_threshold": dec_cfg['TEST']['OVERLAP_THRESHOLD'],
+            "metadata": None,
+            "size_divisibility": dec_cfg['SIZE_DIVISIBILITY'],
+            "sem_seg_postprocess_before_inference": (
+                dec_cfg['TEST']['SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE']
+                or dec_cfg['TEST']['PANOPTIC_ON']
+                or dec_cfg['TEST']['INSTANCE_ON']
+            ),
+            "pixel_mean": cfg['INPUT']['PIXEL_MEAN'],
+            "pixel_std": cfg['INPUT']['PIXEL_STD'],
+            "task_switch": task_switch,
+            "phrase_prob": phrase_prob,
+            # inference
+            "semantic_on": dec_cfg['TEST']['SEMANTIC_ON'],
+            "instance_on": dec_cfg['TEST']['INSTANCE_ON'],
+            "panoptic_on": dec_cfg['TEST']['PANOPTIC_ON'],
+            "test_topk_per_image": cfg['MODEL']['DECODER']['TEST']['DETECTIONS_PER_IMAGE'],
+            "train_dataset_name": train_dataset_name,
+            "retrieval_emsemble": dec_cfg['RETRIEVAL']['ENSEMBLE'],
+            "backbone_dim": cfg['MODEL']['BACKBONE_DIM'],
+            "dim_proj": cfg['MODEL']['DIM_PROJ'],
+        }
+
+    @property
+    def device(self):
+        return self.pixel_mean.device
+
+    def forward(self, batched_inputs, mode=None):
+        if self.training:
+            assert False, "Training mode is not supported."
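+        # Inference-only entry point: the branches below dispatch on the requested
+        # evaluation mode (retrieval, captioning, classification, grounding, default).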
+ else: + if mode == 'retrieval': + return self.evaluate_retrieval(batched_inputs) + elif mode == 'captioning': + return self.evaluate_captioning(batched_inputs) + elif mode == 'classification': + return self.evaluate_classification(batched_inputs) + elif mode in ['grounding_phrasecut', 'grounding_refcoco']: + return self.evaluate_grounding(batched_inputs, mode) + else: + return self.evaluate(batched_inputs) + + def evaluate(self, batched_inputs): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + + images = ImageList.from_tensors(images, self.size_divisibility) + img_bs = images.tensor.shape[0] + + targets = targets_grounding = queries_grounding = None + features = self.backbone(images.tensor) + outputs = self.sem_seg_head(features, target_queries=queries_grounding) + + mask_cls_results = outputs["pred_logits"] + mask_pred_results = outputs["pred_masks"] + box_pred_results = outputs["pred_boxes"] if self.task_switch['bbox'] else [None for i in range(len(mask_pred_results))] + caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))] + + # upsample masks + mask_pred_results = F.interpolate( + mask_pred_results, + size=(images.tensor.shape[-2], images.tensor.shape[-1]), + mode="bilinear", + align_corners=False, + ) + + input_size = mask_pred_results.shape[-2:] + keep_sem_bgd = self.metadata.keep_sem_bgd if hasattr(self.metadata, 'keep_sem_bgd') else False + del outputs + + processed_results = [] + for mask_cls_result, mask_pred_result, box_pred_result, caption_pred_result, input_per_image, image_size in zip( + mask_cls_results, mask_pred_results, box_pred_results, caption_pred_results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + processed_results.append({}) + + if self.sem_seg_postprocess_before_inference: + mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( + mask_pred_result, image_size, height, width + ) + mask_cls_result = mask_cls_result.to(mask_pred_result) + + # semantic segmentation inference + if self.semantic_on: + r = retry_if_cuda_oom(self.semantic_inference)(mask_cls_result, mask_pred_result, keep_sem_bgd) + if not self.sem_seg_postprocess_before_inference: + r = retry_if_cuda_oom(sem_seg_postprocess)(r, image_size, height, width) + processed_results[-1]["sem_seg"] = r + + # panoptic segmentation inference + if self.panoptic_on: + panoptic_r = retry_if_cuda_oom(self.panoptic_inference)(mask_cls_result, mask_pred_result) + processed_results[-1]["panoptic_seg"] = panoptic_r + + # instance segmentation inference + if self.instance_on: + if self.task_switch['bbox']: + box_pred_result = bbox_postprocess(box_pred_result, input_size, image_size, height, width) + instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result, box_pred_result) + processed_results[-1]["instances"] = instance_r + if self.task_switch['caption']: + processed_results[-1]["captions"] = caption_pred_result + processed_results[-1]["masks"] = mask_pred_result + + return processed_results + + + def evaluate_retrieval(self, batched_inputs): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + img_bs = images.tensor.shape[0] + + targets = targets_grounding = queries_grounding = 
None + features = self.backbone(images.tensor) + outputs = self.sem_seg_head(features, target_queries=queries_grounding) + v_emb_it = outputs['pred_captions'][:,-1] + + # compute backbone score + if self.task_switch['retrieval'] and self.retrieval_emsemble: + _v_emb_it = features['res5'] + bs,nc,_,_ = _v_emb_it.shape + _v_emb_it = _v_emb_it.reshape(bs,nc,-1) + _v_emb_it = F.adaptive_avg_pool1d(_v_emb_it, 1).reshape(bs,nc) @ self.backbone_proj + + processed_results = [] + for idx, batch_data in enumerate(batched_inputs): + caption_ids = [] + t_emb_its = [] + processed_results.append({}) + for caption in batch_data['captions']: + lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(caption) + t_emb_it = lang_results['class_emb'] + caption_ids.append(batch_data['image_id']) + t_emb_its.append(t_emb_it) + + t_emb_it = torch.cat(t_emb_its, dim=0) + + image_embeds = [v_emb_it[idx].unsqueeze(0)] + if self.task_switch['retrieval'] and self.retrieval_emsemble: + image_embeds += [_v_emb_it[idx].unsqueeze(0)] + caption_results = { + 'image_embeds': image_embeds, + 'text_embeds': t_emb_it, + 'caption_ids': caption_ids, + 'image_ids': batch_data['image_id'], + } + processed_results[-1]["caption"] = caption_results + return processed_results + + def evaluate_captioning(self, batched_inputs, extra={}): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + img_bs = images.tensor.shape[0] + + if not hasattr(self, 'start_token'): + self.start_token = torch.tensor([[49406]*77], device=self.device) + + targets = targets_grounding = queries_grounding = None + features = self.backbone(images.tensor) + + captioning_mask = None + if 'captioning_mask' in batched_inputs[-1]: + captioning_mask = torch.cat([x['captioning_mask'] for x in batched_inputs]) + + extra.update({'start_token': self.start_token, 'captioning_mask': captioning_mask}) + outputs = self.sem_seg_head(features, target_queries=queries_grounding, task='captioning_infer', extra=extra) + + processed_results = [] + for idx, batch_data in enumerate(batched_inputs): + processed_results.append({}) + processed_results[-1]["captioning_token"] = outputs['pred_captionings'][idx] + processed_results[-1]["captioning_text"] = outputs['pred_texts'][idx].split('.')[0] + processed_results[-1]["image_id"] = batched_inputs[idx]['image_id'] + + return processed_results + + def evaluate_classification(self, batched_inputs): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + img_bs = images.tensor.shape[0] + + targets = targets_grounding = queries_grounding = None + features = self.backbone(images.tensor) + outputs = self.sem_seg_head(features, target_queries=queries_grounding) + + processed_results = [] + for idx, batch_data in enumerate(batched_inputs): + processed_results.append({}) + processed_results[-1]["pred_class"] = outputs['pred_logits'][idx,-1] + return processed_results + + def evaluate_grounding_baseline(self, batched_inputs, mode): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + img_bs = images.tensor.shape[0] + + targets = targets_grounding = queries_grounding = None + features = 
self.backbone(images.tensor) + outputs = self.sem_seg_head(features, target_queries=queries_grounding) + + mask_pred_results = outputs["pred_masks"] + caption_pred_results = outputs["pred_captions"] if self.task_switch['caption'] else [None for i in range(len(mask_pred_results))] + + # upsample masks + mask_pred_results = F.interpolate( + mask_pred_results, + size=(images.tensor.shape[-2], images.tensor.shape[-1]), + mode="bilinear", + align_corners=False, + ) + + processed_results = [] + for mask_pred_result, caption_pred_result, input_per_image, image_size in zip( + mask_pred_results, caption_pred_results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + processed_results.append({}) + + mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( + mask_pred_result, image_size, height, width + )[:-1] + + texts_all = input_per_image['groundings']['texts'] + grd_masks = [] + for texts in texts_all: + if mode == 'grounding_refcoco': + self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=False, is_eval=True) + elif mode == 'grounding_phrasecut': + self.sem_seg_head.predictor.lang_encoder.get_text_embeddings(texts, name='grounding', prompt=True, is_eval=False) + t_emb = getattr(self.sem_seg_head.predictor.lang_encoder, "{}_text_embeddings".format('grounding')).t() + v_emb = caption_pred_result[:-1] + v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + vt_sim = v_emb @ t_emb + max_id = vt_sim.max(0)[1][0] + grd_masks += [mask_pred_result[max_id]] + processed_results[-1]['grounding_mask'] = torch.stack(grd_masks) + + return processed_results + + def evaluate_grounding(self, batched_inputs, mode): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, self.size_divisibility) + + extra = {} + # mask_pred_results = [] + # for idx, batch_per_image in enumerate(batched_inputs): + # grd_texts = batch_per_image['groundings']['texts'] + # grd_masks = [] + # for anno_text in grd_texts: + # gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings([anno_text[0]], name='grounding', token=False, norm=False) + # token_emb = gtext['token_emb'] + # tokens = gtext['tokens'] + + # grd_emb = token_emb[0][tokens['attention_mask'].bool()[0]] + # extra['grounding_tokens'] = grd_emb[:,None] + + # assert len(images.tensor) == 1, "grounding evaluation only support single batch size now" + # features = self.backbone(images.tensor) + # outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval') + + # pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1] + # v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1] + # t_emb = grd_emb[-1:] + + # t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7) + # v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + + # temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale + # out_prob = vl_similarity(v_emb, t_emb, temperature=temperature) + + # matched_id = out_prob.max(0)[1] + # grd_masks += [pred_gmasks[matched_id,:,:]] + # mask_pred_results += [torch.cat(grd_masks)] + + # comment for multi object inference. 
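+        # Multi-object grounding inference (summary of the loop below): all referring
+        # phrases of an image are embedded in one batch, their token embeddings are
+        # passed to the decoder as extra grounding queries, and for every phrase the
+        # predicted mask whose visual embedding has the highest vl_similarity with the
+        # phrase embedding is kept.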
+ mask_pred_results = [] + for idx, batch_per_image in enumerate(batched_inputs): + grd_texts = batch_per_image['groundings']['texts'] + grd_texts = [x[0] for x in grd_texts] + + gtext = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(grd_texts, name='grounding', token=False, norm=False) + token_emb = gtext['token_emb'] + tokens = gtext['tokens'] + query_emb = token_emb[tokens['attention_mask'].bool()] + extra['grounding_tokens'] = query_emb[:,None] + + features = self.backbone(images.tensor) + outputs = self.sem_seg_head(features, extra=extra, task='grounding_eval') + + pred_gmasks = outputs['pred_masks'][idx,self.num_queries:2*self.num_queries-1] + v_emb = outputs['pred_captions'][idx,self.num_queries:2*self.num_queries-1] + t_emb = gtext['class_emb'] + + t_emb = t_emb / (t_emb.norm(dim=-1, keepdim=True) + 1e-7) + v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + + temperature = self.sem_seg_head.predictor.lang_encoder.logit_scale + out_prob = vl_similarity(v_emb, t_emb, temperature=temperature) + + matched_id = out_prob.max(0)[1] + mask_pred_results += [pred_gmasks[matched_id,:,:]] + + for i in range(len(mask_pred_results)): + # upsample masks + mask_pred_results[i] = F.interpolate( + mask_pred_results[i][None,], + size=(images.tensor.shape[-2], images.tensor.shape[-1]), + mode="bilinear", + align_corners=False, + )[0] + + processed_results = [] + for mask_pred_result, input_per_image, image_size in zip( + mask_pred_results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + processed_results.append({}) + + mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( + mask_pred_result, image_size, height, width + ) + processed_results[-1]['grounding_mask'] = mask_pred_result + + # compute bbox + # bbox = BitMasks(mask_pred_result > 0).get_bounding_boxes() + # bbox = BoxMode.convert(bbox.tensor, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + # processed_results[-1]['grounding_box'] = bbox + + return processed_results + + def prepare_vlp_targets(self, batched_inputs, device): + input_ids = [] + attention_mask = [] + for cnt, x in enumerate(batched_inputs): + captions = x['captions'] + randid = random.randint(0, len(captions)-1) + input_ids += x['tokens']['input_ids'][randid:randid+1] + attention_mask += x['tokens']['attention_mask'][randid:randid+1] + + input_ids = torch.stack(input_ids) + attention_mask = torch.stack(attention_mask) + tokens = {"input_ids": input_ids, "attention_mask": attention_mask} + lang_results = self.sem_seg_head.predictor.lang_encoder.get_text_token_embeddings(tokens, token=True) + + target_vlp = [] + for cnt, x in enumerate(batched_inputs): + target_dict = {} + target_dict["caption_tokens"] = lang_results['token_emb'][cnt:cnt+1] + target_dict["caption_proj"] = lang_results['class_emb'][cnt:cnt+1] + target_dict["caption_tokenids"] = lang_results['tokens']['input_ids'][cnt:cnt+1] + target_dict["caption_mask"] = lang_results['tokens']['attention_mask'][cnt:cnt+1] + target_vlp.append(target_dict) + return target_vlp + + def semantic_inference(self, mask_cls, mask_pred, keep_sem_bgd=False): + if keep_sem_bgd: + mask_cls = F.softmax(mask_cls, dim=-1) + else: + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) + return semseg + + def panoptic_inference(self, mask_cls, mask_pred): + scores, labels = F.softmax(mask_cls, dim=-1).max(-1) + mask_pred = 
mask_pred.sigmoid() + + keep = labels.ne(self.sem_seg_head.num_classes) & (scores > self.object_mask_threshold) + cur_scores = scores[keep] + cur_classes = labels[keep] + cur_masks = mask_pred[keep] + cur_mask_cls = mask_cls[keep] + cur_mask_cls = cur_mask_cls[:, :-1] + cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks + + h, w = cur_masks.shape[-2:] + panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) + segments_info = [] + + current_segment_id = 0 + + if cur_masks.shape[0] == 0: + # We didn't detect any mask :( + return panoptic_seg, segments_info + else: + # take argmax + cur_mask_ids = cur_prob_masks.argmax(0) + stuff_memory_list = {} + thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {} + for k in range(cur_classes.shape[0]): + pred_class = cur_classes[k].item() + isthing = pred_class in thing_dataset_id_to_contiguous_id.values() + mask_area = (cur_mask_ids == k).sum().item() + original_area = (cur_masks[k] >= 0.5).sum().item() + mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5) + + if mask_area > 0 and original_area > 0 and mask.sum().item() > 0: + if mask_area / original_area < self.overlap_threshold: + continue + + # merge stuff regions + if not isthing: + if int(pred_class) in stuff_memory_list.keys(): + panoptic_seg[mask] = stuff_memory_list[int(pred_class)] + continue + else: + stuff_memory_list[int(pred_class)] = current_segment_id + 1 + + current_segment_id += 1 + panoptic_seg[mask] = current_segment_id + + segments_info.append( + { + "id": current_segment_id, + "isthing": bool(isthing), + "category_id": int(pred_class), + } + ) + return panoptic_seg, segments_info + + def instance_inference(self, mask_cls, mask_pred, box_pred): + # mask_pred is already processed to have the same shape as original input + image_size = mask_pred.shape[-2:] + + # [Q, K] + scores = F.softmax(mask_cls, dim=-1)[:, :-1] + labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1) + # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False) + scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False) + + labels_per_image = labels[topk_indices] + topk_indices = (topk_indices // self.sem_seg_head.num_classes) + # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1) + mask_pred = mask_pred[topk_indices] + if box_pred is not None: + box_pred = box_pred[topk_indices] + + # if this is panoptic segmentation, we only keep the "thing" classes + if self.panoptic_on: + thing_dataset_id_to_contiguous_id = self.metadata.thing_dataset_id_to_contiguous_id if hasattr(self.metadata, 'thing_dataset_id_to_contiguous_id') else {} + keep = torch.zeros_like(scores_per_image).bool() + for i, lab in enumerate(labels_per_image): + keep[i] = lab in thing_dataset_id_to_contiguous_id.values() + + scores_per_image = scores_per_image[keep] + labels_per_image = labels_per_image[keep] + mask_pred = mask_pred[keep] + + if box_pred is not None: + box_pred = box_pred[keep] + + result = Instances(image_size) + # mask (before sigmoid) + result.pred_masks = (mask_pred > 0).float() + # result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4)) + # Uncomment the following to get boxes from masks (this is slow) + + if box_pred is not None: + result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes() + else: + 
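+            # No box predictions available: fall back to all-zero placeholder boxes so
+            # downstream consumers that expect ``pred_boxes`` still work.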
result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4)) + + # calculate average mask prob + mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6) + result.scores = scores_per_image * mask_scores_per_image + result.pred_classes = labels_per_image + + return result + + +@register_model +def get_segmentation_model(cfg, **kwargs): + return X_Decoder_Model(cfg) \ No newline at end of file diff --git a/xdecoder/backbone/__init__.py b/xdecoder/backbone/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..aac67442cdce051f7f9de6068990eb388b8dd3bb --- /dev/null +++ b/xdecoder/backbone/__init__.py @@ -0,0 +1,7 @@ +from .build import build_backbone + +from .resnet import * +from .swin import * +from .focal import * +from .focal_dw import * +from .backbone import * \ No newline at end of file diff --git a/xdecoder/backbone/__pycache__/__init__.cpython-38.pyc b/xdecoder/backbone/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bc2803bc74adb5e56bd8813e9613cd467866792 Binary files /dev/null and b/xdecoder/backbone/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/backbone/__pycache__/backbone.cpython-38.pyc b/xdecoder/backbone/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4841be625a3107ca3660cfd47bb8cbdb0b73ae9 Binary files /dev/null and b/xdecoder/backbone/__pycache__/backbone.cpython-38.pyc differ diff --git a/xdecoder/backbone/__pycache__/build.cpython-38.pyc b/xdecoder/backbone/__pycache__/build.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a6eec7234669d3a1c23ba6f5089b33e535e7ca4 Binary files /dev/null and b/xdecoder/backbone/__pycache__/build.cpython-38.pyc differ diff --git a/xdecoder/backbone/__pycache__/focal.cpython-38.pyc b/xdecoder/backbone/__pycache__/focal.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b30a1b496ac17b49140ba85b2dc71896f72f277 Binary files /dev/null and b/xdecoder/backbone/__pycache__/focal.cpython-38.pyc differ diff --git a/xdecoder/backbone/__pycache__/focal_dw.cpython-38.pyc b/xdecoder/backbone/__pycache__/focal_dw.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6564f15b345316c10577d47a4b060f611120de56 Binary files /dev/null and b/xdecoder/backbone/__pycache__/focal_dw.cpython-38.pyc differ diff --git a/xdecoder/backbone/__pycache__/registry.cpython-38.pyc b/xdecoder/backbone/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..262db2e1fa3699b90fd15f0d2e1ffe9d6e503d23 Binary files /dev/null and b/xdecoder/backbone/__pycache__/registry.cpython-38.pyc differ diff --git a/xdecoder/backbone/__pycache__/resnet.cpython-38.pyc b/xdecoder/backbone/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b9ac54cec5ecda11a40a70d1602f454c9564182 Binary files /dev/null and b/xdecoder/backbone/__pycache__/resnet.cpython-38.pyc differ diff --git a/xdecoder/backbone/__pycache__/swin.cpython-38.pyc b/xdecoder/backbone/__pycache__/swin.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b89ab8de1bdede2b3be3eca283f143719925207 Binary files /dev/null and b/xdecoder/backbone/__pycache__/swin.cpython-38.pyc differ diff --git a/xdecoder/backbone/backbone.py b/xdecoder/backbone/backbone.py new file mode 
100755 index 0000000000000000000000000000000000000000..503f74a69288b3696bebf12992f21ad5781e47aa --- /dev/null +++ b/xdecoder/backbone/backbone.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch.nn as nn + +from detectron2.modeling import ShapeSpec + +__all__ = ["Backbone"] + + +class Backbone(nn.Module): + """ + Abstract base class for network backbones. + """ + + def __init__(self): + """ + The `__init__` method of any subclass can specify its own set of arguments. + """ + super().__init__() + + def forward(self): + """ + Subclasses must override this method, but adhere to the same return type. + + Returns: + dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor + """ + pass + + @property + def size_divisibility(self) -> int: + """ + Some backbones require the input height and width to be divisible by a + specific integer. This is typically true for encoder / decoder type networks + with lateral connection (e.g., FPN) for which feature maps need to match + dimension in the "bottom up" and "top down" paths. Set to 0 if no specific + input size divisibility is required. + """ + return 0 + + def output_shape(self): + """ + Returns: + dict[str->ShapeSpec] + """ + # this is a backward-compatible default + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } diff --git a/xdecoder/backbone/build.py b/xdecoder/backbone/build.py new file mode 100755 index 0000000000000000000000000000000000000000..a559fa6a010d3379ff5fcbeb43c510122988735f --- /dev/null +++ b/xdecoder/backbone/build.py @@ -0,0 +1,11 @@ +from .registry import model_entrypoints +from .registry import is_model + +from .backbone import * + +def build_backbone(config, **kwargs): + model_name = config['MODEL']['BACKBONE']['NAME'] + if not is_model(model_name): + raise ValueError(f'Unkown model: {model_name}') + + return model_entrypoints(model_name)(config, **kwargs) \ No newline at end of file diff --git a/xdecoder/backbone/focal.py b/xdecoder/backbone/focal.py new file mode 100755 index 0000000000000000000000000000000000000000..eb08555d2f5a036d175ee94033d8cae30d0ff959 --- /dev/null +++ b/xdecoder/backbone/focal.py @@ -0,0 +1,692 @@ +# -------------------------------------------------------- +# FocalNet for Semantic Segmentation +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Jianwei Yang +# -------------------------------------------------------- +import math +import time +import numpy as np +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from detectron2.utils.file_io import PathManager +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec + +from .registry import register_backbone + +logger = logging.getLogger(__name__) + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = 
self.drop(x) + return x + +class FocalModulation(nn.Module): + """ Focal Modulation + + Args: + dim (int): Number of input channels. + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + focal_level (int): Number of focal levels + focal_window (int): Focal window size at focal level 1 + focal_factor (int, default=2): Step to increase the focal window + use_postln (bool, default=False): Whether use post-modulation layernorm + """ + + def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False): + + super().__init__() + self.dim = dim + + # specific args for focalv3 + self.focal_level = focal_level + self.focal_window = focal_window + self.focal_factor = focal_factor + self.use_postln_in_modulation = use_postln_in_modulation + self.scaling_modulator = scaling_modulator + + self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True) + self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True) + + self.act = nn.GELU() + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.focal_layers = nn.ModuleList() + + if self.use_postln_in_modulation: + self.ln = nn.LayerNorm(dim) + + for k in range(self.focal_level): + kernel_size = self.focal_factor*k + self.focal_window + self.focal_layers.append( + nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, + padding=kernel_size//2, bias=False), + nn.GELU(), + ) + ) + + def forward(self, x): + """ Forward function. + + Args: + x: input features with shape of (B, H, W, C) + """ + B, nH, nW, C = x.shape + x = self.f(x) + x = x.permute(0, 3, 1, 2).contiguous() + q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1) + + ctx_all = 0 + for l in range(self.focal_level): + ctx = self.focal_layers[l](ctx) + ctx_all = ctx_all + ctx*gates[:, l:l+1] + ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) + ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:] + + if self.scaling_modulator: + ctx_all = ctx_all / (self.focal_level + 1) + + x_out = q * self.h(ctx_all) + x_out = x_out.permute(0, 2, 3, 1).contiguous() + if self.use_postln_in_modulation: + x_out = self.ln(x_out) + x_out = self.proj(x_out) + x_out = self.proj_drop(x_out) + return x_out + +class FocalModulationBlock(nn.Module): + """ Focal Modulation Block. + + Args: + dim (int): Number of input channels. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + focal_level (int): number of focal levels + focal_window (int): focal kernel size at level 1 + """ + + def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + focal_level=2, focal_window=9, + use_postln=False, use_postln_in_modulation=False, + scaling_modulator=False, + use_layerscale=False, + layerscale_value=1e-4): + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.focal_window = focal_window + self.focal_level = focal_level + self.use_postln = use_postln + self.use_layerscale = use_layerscale + + self.norm1 = norm_layer(dim) + self.modulation = FocalModulation( + dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + self.gamma_1 = 1.0 + self.gamma_2 = 1.0 + if self.use_layerscale: + self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + if not self.use_postln: + x = self.norm1(x) + x = x.view(B, H, W, C) + + # FM + x = self.modulation(x).view(B, H * W, C) + if self.use_postln: + x = self.norm1(x) + + # FFN + x = shortcut + self.drop_path(self.gamma_1 * x) + + if self.use_postln: + x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x))) + else: + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + + return x + +class BasicLayer(nn.Module): + """ A basic focal modulation layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + focal_level (int): Number of focal levels + focal_window (int): Focal window size at focal level 1 + use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + """ + + def __init__(self, + dim, + depth, + mlp_ratio=4., + drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + focal_window=9, + focal_level=2, + use_conv_embed=False, + use_postln=False, + use_postln_in_modulation=False, + scaling_modulator=False, + use_layerscale=False, + use_checkpoint=False + ): + super().__init__() + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + FocalModulationBlock( + dim=dim, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + focal_window=focal_window, + focal_level=focal_level, + use_postln=use_postln, + use_postln_in_modulation=use_postln_in_modulation, + scaling_modulator=scaling_modulator, + use_layerscale=use_layerscale, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + patch_size=2, + in_chans=dim, embed_dim=2*dim, + use_conv_embed=use_conv_embed, + norm_layer=norm_layer, + is_stem=False + ) + + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W) + x_down = self.downsample(x_reshaped) + x_down = x_down.flatten(2).transpose(1, 2) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False + is_stem (bool): Is the stem block or not. + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if use_conv_embed: + # if we choose to use conv embedding, then we treat the stem and non-stem differently + if is_stem: + kernel_size = 7; padding = 2; stride = 4 + else: + kernel_size = 3; padding = 1; stride = 2 + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + else: + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class FocalNet(nn.Module): + """ FocalNet backbone. 
+ + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + drop_rate (float): Dropout rate. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + focal_levels (Sequence[int]): Number of focal levels at four stages + focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages + use_conv_embed (bool): Whether use overlapped convolution for patch embedding + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + pretrain_img_size=1600, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + out_indices=[0, 1, 2, 3], + frozen_stages=-1, + focal_levels=[2,2,2,2], + focal_windows=[9,9,9,9], + use_conv_embed=False, + use_postln=False, + use_postln_in_modulation=False, + scaling_modulator=False, + use_layerscale=False, + use_checkpoint=False, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + use_conv_embed=use_conv_embed, is_stem=True) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, + focal_window=focal_windows[i_layer], + focal_level=focal_levels[i_layer], + use_conv_embed=use_conv_embed, + use_postln=use_postln, + use_postln_in_modulation=use_postln_in_modulation, + scaling_modulator=scaling_modulator, + use_layerscale=use_layerscale, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if 
self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): + model_dict = self.state_dict() + + missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict] + logger.info(f'=> Missed keys {missed_dict}') + unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict] + logger.info(f'=> Unexpected keys {unexpected_dict}') + + pretrained_dict = { + k: v for k, v in pretrained_dict.items() + if k in model_dict.keys() + } + + need_init_state_dict = {} + for k, v in pretrained_dict.items(): + need_init = ( + ( + k.split('.')[0] in pretrained_layers + or pretrained_layers[0] == '*' + ) + and 'relative_position_index' not in k + and 'attn_mask' not in k + ) + + if need_init: + # if verbose: + # logger.info(f'=> init {k} from {pretrained}') + + if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size(): + table_pretrained = v + table_current = model_dict[k] + fsize1 = table_pretrained.shape[2] + fsize2 = table_current.shape[2] + + # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv + if fsize1 < fsize2: + table_pretrained_resized = torch.zeros(table_current.shape) + table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained + v = table_pretrained_resized + elif fsize1 > fsize2: + table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2] + v = table_pretrained_resized + + + if ("modulation.f" in k or "pre_conv" in k): + table_pretrained = v + table_current = model_dict[k] + if table_pretrained.shape != table_current.shape: + if len(table_pretrained.shape) == 2: + dim = table_pretrained.shape[1] + assert table_current.shape[1] == dim + L1 = table_pretrained.shape[0] + L2 = table_current.shape[0] + + if L1 < L2: + table_pretrained_resized = torch.zeros(table_current.shape) + # copy for linear project + table_pretrained_resized[:2*dim] = table_pretrained[:2*dim] + # copy for global token gating + table_pretrained_resized[-1] = table_pretrained[-1] + # copy for first multiple focal levels + table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] + # reassign pretrained weights + v = table_pretrained_resized + elif L1 > L2: + raise NotImplementedError + elif len(table_pretrained.shape) == 1: + dim = table_pretrained.shape[0] + L1 = table_pretrained.shape[0] + L2 = table_current.shape[0] + if L1 < L2: + table_pretrained_resized = torch.zeros(table_current.shape) + # copy for linear project + 
table_pretrained_resized[:dim] = table_pretrained[:dim] + # copy for global token gating + table_pretrained_resized[-1] = table_pretrained[-1] + # copy for first multiple focal levels + # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] + # reassign pretrained weights + v = table_pretrained_resized + elif L1 > L2: + raise NotImplementedError + + need_init_state_dict[k] = v + + self.load_state_dict(need_init_state_dict, strict=False) + + + def forward(self, x): + """Forward function.""" + tic = time.time() + x = self.patch_embed(x) + Wh, Ww = x.size(2), x.size(3) + + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = {} + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs["res{}".format(i + 2)] = out + + if len(self.out_indices) == 0: + outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + + toc = time.time() + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(FocalNet, self).train(mode) + self._freeze_stages() + + +class D2FocalNet(FocalNet, Backbone): + def __init__(self, cfg, input_shape): + + pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE'] + patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE'] + in_chans = 3 + embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM'] + depths = cfg['BACKBONE']['FOCAL']['DEPTHS'] + mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO'] + drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE'] + drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE'] + norm_layer = nn.LayerNorm + patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM'] + use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT'] + out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES'] + scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False) + + super().__init__( + pretrain_img_size, + patch_size, + in_chans, + embed_dim, + depths, + mlp_ratio, + drop_rate, + drop_path_rate, + norm_layer, + patch_norm, + out_indices, + focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'], + focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'], + use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'], + use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'], + use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], + scaling_modulator=scaling_modulator, + use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], + use_checkpoint=use_checkpoint, + ) + + self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES'] + + self._out_feature_strides = { + "res2": 4, + "res3": 8, + "res4": 16, + "res5": 32, + } + self._out_feature_channels = { + "res2": self.num_features[0], + "res3": self.num_features[1], + "res4": self.num_features[2], + "res5": self.num_features[3], + } + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert ( + x.dim() == 4 + ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" 
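(Editorial aside on the `load_weights` logic shown a few lines above: when a pretrained focal/pool depthwise kernel has a different spatial size than the current model, the code center-pads a smaller kernel into the larger slot, or center-crops a larger one, rather than interpolating. A minimal sketch of the padding branch follows; the helper name and shapes are illustrative, not part of the repo.)

```python
import torch

def center_pad_kernel(pretrained: torch.Tensor, target_shape) -> torch.Tensor:
    """Center-place a smaller depthwise conv kernel inside a larger zero-filled one,
    mirroring the fsize1 < fsize2 branch of load_weights (even size gap assumed)."""
    k_small, k_large = pretrained.shape[2], target_shape[2]
    assert k_small < k_large and (k_large - k_small) % 2 == 0
    out = torch.zeros(target_shape)
    off = (k_large - k_small) // 2
    out[:, :, off:-off, off:-off] = pretrained
    return out

# e.g. grow a pretrained 7x7 focal kernel into a 9x9 slot: (dim, 1, 7, 7) -> (dim, 1, 9, 9)
w_new = center_pad_kernel(torch.randn(96, 1, 7, 7), (96, 1, 9, 9))
print(w_new.shape)  # torch.Size([96, 1, 9, 9])
```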
+ outputs = {} + y = super().forward(x) + for k in y.keys(): + if k in self._out_features: + outputs[k] = y[k] + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + @property + def size_divisibility(self): + return 32 + +@register_backbone +def get_focal_backbone(cfg): + focal = D2FocalNet(cfg['MODEL'], 224) + + if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: + filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] + logger.info(f'=> init from {filename}') + with PathManager.open(filename, "rb") as f: + ckpt = torch.load(f)['model'] + focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) + + return focal \ No newline at end of file diff --git a/xdecoder/backbone/focal_dw.py b/xdecoder/backbone/focal_dw.py new file mode 100755 index 0000000000000000000000000000000000000000..4306ec6fc347a8e5798f79ba9e08e1a1d687fbb5 --- /dev/null +++ b/xdecoder/backbone/focal_dw.py @@ -0,0 +1,789 @@ +# -------------------------------------------------------- +# FocalNet for Semantic Segmentation +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Jianwei Yang +# -------------------------------------------------------- +import math +import time +import numpy as np +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from detectron2.utils.file_io import PathManager +from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec + +from .registry import register_backbone + +logger = logging.getLogger(__name__) + +class Mlp(nn.Module): + """ Multilayer perceptron.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class FocalModulation(nn.Module): + """ Focal Modulation + + Args: + dim (int): Number of input channels. + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + focal_level (int): Number of focal levels + focal_window (int): Focal window size at focal level 1 + focal_factor (int, default=2): Step to increase the focal window + use_postln (bool, default=False): Whether use post-modulation layernorm + """ + + def __init__(self, dim, proj_drop=0., focal_level=2, focal_window=7, focal_factor=2, use_postln=False, use_postln_in_modulation=False, scaling_modulator=False): + + super().__init__() + self.dim = dim + + # specific args for focalv3 + self.focal_level = focal_level + self.focal_window = focal_window + self.focal_factor = focal_factor + self.use_postln_in_modulation = use_postln_in_modulation + self.scaling_modulator = scaling_modulator + + self.f = nn.Linear(dim, 2*dim+(self.focal_level+1), bias=True) + self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, groups=1, bias=True) + + self.act = nn.GELU() + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.focal_layers = nn.ModuleList() + + if self.use_postln_in_modulation: + self.ln = nn.LayerNorm(dim) + + for k in range(self.focal_level): + kernel_size = self.focal_factor*k + self.focal_window + self.focal_layers.append( + nn.Sequential( + nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, groups=dim, + padding=kernel_size//2, bias=False), + nn.GELU(), + ) + ) + + def forward(self, x): + """ Forward function. + + Args: + x: input features with shape of (B, H, W, C) + """ + B, nH, nW, C = x.shape + x = self.f(x) + x = x.permute(0, 3, 1, 2).contiguous() + q, ctx, gates = torch.split(x, (C, C, self.focal_level+1), 1) + + ctx_all = 0 + for l in range(self.focal_level): + ctx = self.focal_layers[l](ctx) + ctx_all = ctx_all + ctx*gates[:, l:l+1] + ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) + ctx_all = ctx_all + ctx_global*gates[:,self.focal_level:] + + if self.scaling_modulator: + ctx_all = ctx_all / (self.focal_level + 1) + + x_out = q * self.h(ctx_all) + x_out = x_out.permute(0, 2, 3, 1).contiguous() + if self.use_postln_in_modulation: + x_out = self.ln(x_out) + x_out = self.proj(x_out) + x_out = self.proj_drop(x_out) + return x_out + +class FocalModulationBlock(nn.Module): + """ Focal Modulation Block. + + Args: + dim (int): Number of input channels. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + focal_level (int): number of focal levels + focal_window (int): focal kernel size at level 1 + """ + + def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + focal_level=2, focal_window=9, + use_postln=False, use_postln_in_modulation=False, + scaling_modulator=False, + use_layerscale=False, + layerscale_value=1e-4): + super().__init__() + self.dim = dim + self.mlp_ratio = mlp_ratio + self.focal_window = focal_window + self.focal_level = focal_level + self.use_postln = use_postln + self.use_layerscale = use_layerscale + + self.dw1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) + self.norm1 = norm_layer(dim) + self.modulation = FocalModulation( + dim, focal_window=self.focal_window, focal_level=self.focal_level, proj_drop=drop, use_postln_in_modulation=use_postln_in_modulation, scaling_modulator=scaling_modulator + ) + + self.dw2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + + self.gamma_1 = 1.0 + self.gamma_2 = 1.0 + if self.use_layerscale: + self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) + self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() + x = x + self.dw1(x) + x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) + + shortcut = x + if not self.use_postln: + x = self.norm1(x) + x = x.view(B, H, W, C) + + # FM + x = self.modulation(x).view(B, H * W, C) + x = shortcut + self.drop_path(self.gamma_1 * x) + if self.use_postln: + x = self.norm1(x) + + x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous() + x = x + self.dw2(x) + x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C) + + if not self.use_postln: + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_2 * self.mlp(x)) + x = self.norm2(x) + + return x + +class BasicLayer(nn.Module): + """ A basic focal modulation layer for one stage. + + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + focal_level (int): Number of focal levels + focal_window (int): Focal window size at focal level 1 + use_conv_embed (bool): Use overlapped convolution for patch embedding or now. Default: False + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + """ + + def __init__(self, + dim, + depth, + mlp_ratio=4., + drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + focal_window=9, + focal_level=2, + use_conv_embed=False, + use_postln=False, + use_postln_in_modulation=False, + scaling_modulator=False, + use_layerscale=False, + use_checkpoint=False, + use_pre_norm=False, + ): + super().__init__() + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + FocalModulationBlock( + dim=dim, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + focal_window=focal_window, + focal_level=focal_level, + use_postln=use_postln, + use_postln_in_modulation=use_postln_in_modulation, + scaling_modulator=scaling_modulator, + use_layerscale=use_layerscale, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + patch_size=2, + in_chans=dim, embed_dim=2*dim, + use_conv_embed=use_conv_embed, + norm_layer=norm_layer, + is_stem=False, + use_pre_norm=use_pre_norm + ) + + else: + self.downsample = None + + def forward(self, x, H, W): + """ Forward function. + + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x_reshaped = x.transpose(1, 2).view(x.shape[0], x.shape[-1], H, W) + x_down = self.downsample(x_reshaped) + x_down = x_down.flatten(2).transpose(1, 2) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +# class PatchEmbed(nn.Module): +# r""" Image to Patch Embedding + +# Args: +# img_size (int): Image size. Default: 224. +# patch_size (int): Patch token size. Default: 4. +# in_chans (int): Number of input image channels. Default: 3. +# embed_dim (int): Number of linear projection output channels. Default: 96. +# norm_layer (nn.Module, optional): Normalization layer. 
Default: None +# """ + +# def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96, +# use_conv_embed=False, norm_layer=None, is_stem=False, use_pre_norm=False): +# super().__init__() +# patch_size = to_2tuple(patch_size) +# patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] +# self.img_size = img_size +# self.patch_size = patch_size +# self.patches_resolution = patches_resolution +# self.num_patches = patches_resolution[0] * patches_resolution[1] + +# self.in_chans = in_chans +# self.embed_dim = embed_dim +# self.use_pre_norm = use_pre_norm + +# if use_conv_embed: +# # if we choose to use conv embedding, then we treat the stem and non-stem differently +# if is_stem: +# kernel_size = 7; padding = 3; stride = 4 +# else: +# kernel_size = 3; padding = 1; stride = 2 +# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) +# else: +# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + +# if self.use_pre_norm: +# if norm_layer is not None: +# self.norm = norm_layer(in_chans) +# else: +# self.norm = None +# else: +# if norm_layer is not None: +# self.norm = norm_layer(embed_dim) +# else: +# self.norm = None + +# def forward(self, x): +# B, C, H, W = x.shape +# # FIXME look at relaxing size constraints +# assert H == self.img_size[0] and W == self.img_size[1], \ +# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + +# if self.use_pre_norm: +# if self.norm is not None: +# x = x.flatten(2).transpose(1, 2) # B Ph*Pw C +# x = self.norm(x).transpose(1, 2).view(B, C, H, W) +# x = self.proj(x).flatten(2).transpose(1, 2) +# else: +# x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C +# if self.norm is not None: +# x = self.norm(x) +# return x + +# def flops(self): +# Ho, Wo = self.patches_resolution +# flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) +# if self.norm is not None: +# flops += Ho * Wo * self.embed_dim +# return flops + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + use_conv_embed (bool): Whether use overlapped convolution for patch embedding. Default: False + is_stem (bool): Is the stem block or not. 
+ """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, use_conv_embed=False, is_stem=False, use_pre_norm=False): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + self.use_pre_norm = use_pre_norm + + if use_conv_embed: + # if we choose to use conv embedding, then we treat the stem and non-stem differently + if is_stem: + kernel_size = 7; padding = 3; stride = 4 + else: + kernel_size = 3; padding = 1; stride = 2 + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding) + else: + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + if self.use_pre_norm: + if norm_layer is not None: + self.norm = norm_layer(in_chans) + else: + self.norm = None + else: + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + B, C, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + if self.use_pre_norm: + if self.norm is not None: + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + x = self.norm(x).transpose(1, 2).view(B, C, H, W) + x = self.proj(x) + else: + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class FocalNet(nn.Module): + """ FocalNet backbone. + + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + drop_rate (float): Dropout rate. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + focal_levels (Sequence[int]): Number of focal levels at four stages + focal_windows (Sequence[int]): Focal window sizes at first focal level at four stages + use_conv_embed (bool): Whether use overlapped convolution for patch embedding + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, + pretrain_img_size=1600, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + patch_norm=True, + out_indices=[0, 1, 2, 3], + frozen_stages=-1, + focal_levels=[2,2,2,2], + focal_windows=[9,9,9,9], + use_pre_norms=[False, False, False, False], + use_conv_embed=False, + use_postln=False, + use_postln_in_modulation=False, + scaling_modulator=False, + use_layerscale=False, + use_checkpoint=False, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + use_conv_embed=use_conv_embed, is_stem=True, use_pre_norm=False) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + mlp_ratio=mlp_ratio, + drop=drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchEmbed if (i_layer < self.num_layers - 1) else None, + focal_window=focal_windows[i_layer], + focal_level=focal_levels[i_layer], + use_pre_norm=use_pre_norms[i_layer], + use_conv_embed=use_conv_embed, + use_postln=use_postln, + use_postln_in_modulation=use_postln_in_modulation, + scaling_modulator=scaling_modulator, + use_layerscale=use_layerscale, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + # self.norm = norm_layer(num_features[-1]) + + # add a norm layer for each output + for i_layer in self.out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. 
+ """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + logger = get_root_logger() + load_checkpoint(self, pretrained, strict=False, logger=logger) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True): + model_dict = self.state_dict() + + missed_dict = [k for k in model_dict.keys() if k not in pretrained_dict] + logger.info(f'=> Missed keys {missed_dict}') + unexpected_dict = [k for k in pretrained_dict.keys() if k not in model_dict] + logger.info(f'=> Unexpected keys {unexpected_dict}') + + pretrained_dict = { + k: v for k, v in pretrained_dict.items() + if k in model_dict.keys() + } + + need_init_state_dict = {} + for k, v in pretrained_dict.items(): + need_init = ( + ( + k.split('.')[0] in pretrained_layers + or pretrained_layers[0] == '*' + ) + and 'relative_position_index' not in k + and 'attn_mask' not in k + ) + + if need_init: + # if verbose: + # logger.info(f'=> init {k} from {pretrained}') + + if ('pool_layers' in k) or ('focal_layers' in k) and v.size() != model_dict[k].size(): + table_pretrained = v + table_current = model_dict[k] + fsize1 = table_pretrained.shape[2] + fsize2 = table_current.shape[2] + + # NOTE: different from interpolation used in self-attention, we use padding or clipping for focal conv + if fsize1 < fsize2: + table_pretrained_resized = torch.zeros(table_current.shape) + table_pretrained_resized[:, :, (fsize2-fsize1)//2:-(fsize2-fsize1)//2, (fsize2-fsize1)//2:-(fsize2-fsize1)//2] = table_pretrained + v = table_pretrained_resized + elif fsize1 > fsize2: + table_pretrained_resized = table_pretrained[:, :, (fsize1-fsize2)//2:-(fsize1-fsize2)//2, (fsize1-fsize2)//2:-(fsize1-fsize2)//2] + v = table_pretrained_resized + + + if ("modulation.f" in k or "pre_conv" in k): + table_pretrained = v + table_current = model_dict[k] + if table_pretrained.shape != table_current.shape: + if len(table_pretrained.shape) == 2: + dim = table_pretrained.shape[1] + assert table_current.shape[1] == dim + L1 = table_pretrained.shape[0] + L2 = table_current.shape[0] + + if L1 < L2: + table_pretrained_resized = torch.zeros(table_current.shape) + # copy for linear project + table_pretrained_resized[:2*dim] = table_pretrained[:2*dim] + # copy for global token gating + table_pretrained_resized[-1] = table_pretrained[-1] + # copy for first multiple focal levels + table_pretrained_resized[2*dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] + # reassign pretrained weights + v = table_pretrained_resized + elif L1 > L2: + raise NotImplementedError + elif len(table_pretrained.shape) == 1: + dim = table_pretrained.shape[0] + L1 = table_pretrained.shape[0] + L2 = table_current.shape[0] + if L1 < L2: + table_pretrained_resized = torch.zeros(table_current.shape) + # copy for linear project + table_pretrained_resized[:dim] = table_pretrained[:dim] + # copy for global token gating + table_pretrained_resized[-1] = table_pretrained[-1] + # copy for first multiple focal levels + # table_pretrained_resized[dim:2*dim+(L1-2*dim-1)] = table_pretrained[2*dim:-1] + # reassign pretrained weights + v = table_pretrained_resized + elif L1 > L2: + raise NotImplementedError + + 
need_init_state_dict[k] = v + + self.load_state_dict(need_init_state_dict, strict=False) + + + def forward(self, x): + """Forward function.""" + tic = time.time() + x = self.patch_embed(x) + Wh, Ww = x.size(2), x.size(3) + + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = {} + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs["res{}".format(i + 2)] = out + + if len(self.out_indices) == 0: + outs["res5"] = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + + toc = time.time() + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(FocalNet, self).train(mode) + self._freeze_stages() + + +class D2FocalNet(FocalNet, Backbone): + def __init__(self, cfg, input_shape): + + pretrain_img_size = cfg['BACKBONE']['FOCAL']['PRETRAIN_IMG_SIZE'] + patch_size = cfg['BACKBONE']['FOCAL']['PATCH_SIZE'] + in_chans = 3 + embed_dim = cfg['BACKBONE']['FOCAL']['EMBED_DIM'] + depths = cfg['BACKBONE']['FOCAL']['DEPTHS'] + mlp_ratio = cfg['BACKBONE']['FOCAL']['MLP_RATIO'] + drop_rate = cfg['BACKBONE']['FOCAL']['DROP_RATE'] + drop_path_rate = cfg['BACKBONE']['FOCAL']['DROP_PATH_RATE'] + norm_layer = nn.LayerNorm + patch_norm = cfg['BACKBONE']['FOCAL']['PATCH_NORM'] + use_checkpoint = cfg['BACKBONE']['FOCAL']['USE_CHECKPOINT'] + out_indices = cfg['BACKBONE']['FOCAL']['OUT_INDICES'] + scaling_modulator = cfg['BACKBONE']['FOCAL'].get('SCALING_MODULATOR', False) + + super().__init__( + pretrain_img_size, + patch_size, + in_chans, + embed_dim, + depths, + mlp_ratio, + drop_rate, + drop_path_rate, + norm_layer, + patch_norm, + out_indices, + focal_levels=cfg['BACKBONE']['FOCAL']['FOCAL_LEVELS'], + focal_windows=cfg['BACKBONE']['FOCAL']['FOCAL_WINDOWS'], + use_conv_embed=cfg['BACKBONE']['FOCAL']['USE_CONV_EMBED'], + use_postln=cfg['BACKBONE']['FOCAL']['USE_POSTLN'], + use_postln_in_modulation=cfg['BACKBONE']['FOCAL']['USE_POSTLN_IN_MODULATION'], + scaling_modulator=scaling_modulator, + use_layerscale=cfg['BACKBONE']['FOCAL']['USE_LAYERSCALE'], + use_checkpoint=use_checkpoint, + ) + + self._out_features = cfg['BACKBONE']['FOCAL']['OUT_FEATURES'] + + self._out_feature_strides = { + "res2": 4, + "res3": 8, + "res4": 16, + "res5": 32, + } + self._out_feature_channels = { + "res2": self.num_features[0], + "res3": self.num_features[1], + "res4": self.num_features[2], + "res5": self.num_features[3], + } + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert ( + x.dim() == 4 + ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" 
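(Editorial aside, for orientation: `build_backbone` looks up `MODEL.BACKBONE.NAME` in the registry, whose keys are the module basenames (`resnet`, `swin`, `focal`, `focal_dw`), and calls the matching `@register_backbone` entrypoint. Below is a hypothetical sketch of the nested config that path expects; the key names come from the `D2FocalNet`/`get_focal_backbone` code in this diff, but the concrete values are illustrative and are not the released X-Decoder settings.)

```python
from xdecoder.backbone import build_backbone

cfg = {
    'VERBOSE': True,
    'MODEL': {
        'BACKBONE': {
            'NAME': 'focal_dw',        # registry key = module basename
            'LOAD_PRETRAINED': False,  # skip the PathManager checkpoint load for this sketch
            'PRETRAINED': '',
            'FOCAL': {
                'PRETRAIN_IMG_SIZE': 224,
                'PATCH_SIZE': 4,
                'EMBED_DIM': 96,
                'DEPTHS': [2, 2, 6, 2],
                'MLP_RATIO': 4.0,
                'DROP_RATE': 0.0,
                'DROP_PATH_RATE': 0.2,
                'PATCH_NORM': True,
                'USE_CHECKPOINT': False,
                'OUT_INDICES': [0, 1, 2, 3],
                'OUT_FEATURES': ['res2', 'res3', 'res4', 'res5'],
                'FOCAL_LEVELS': [2, 2, 2, 2],
                'FOCAL_WINDOWS': [9, 9, 9, 9],
                'USE_CONV_EMBED': False,
                'USE_POSTLN': False,
                'USE_POSTLN_IN_MODULATION': False,
                'USE_LAYERSCALE': False,
            },
        }
    },
}

backbone = build_backbone(cfg)   # returns a D2FocalNet instance
print(backbone.output_shape())   # res2..res5 with strides 4/8/16/32
```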
+ outputs = {} + y = super().forward(x) + for k in y.keys(): + if k in self._out_features: + outputs[k] = y[k] + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + @property + def size_divisibility(self): + return 32 + +@register_backbone +def get_focal_backbone(cfg): + focal = D2FocalNet(cfg['MODEL'], 224) + + if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: + filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] + logger.info(f'=> init from {filename}') + with PathManager.open(filename, "rb") as f: + ckpt = torch.load(f)['model'] + focal.load_weights(ckpt, cfg['MODEL']['BACKBONE']['FOCAL'].get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) + + return focal \ No newline at end of file diff --git a/xdecoder/backbone/registry.py b/xdecoder/backbone/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..9e19cc8068fff5f5de219c0739594b404d837e00 --- /dev/null +++ b/xdecoder/backbone/registry.py @@ -0,0 +1,14 @@ +_model_entrypoints = {} + + +def register_backbone(fn): + module_name_split = fn.__module__.split('.') + model_name = module_name_split[-1] + _model_entrypoints[model_name] = fn + return fn + +def model_entrypoints(model_name): + return _model_entrypoints[model_name] + +def is_model(model_name): + return model_name in _model_entrypoints diff --git a/xdecoder/backbone/resnet.py b/xdecoder/backbone/resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..dbfaa85ccb7937b93fc7f8a0ca57cc2e785ec2e6 --- /dev/null +++ b/xdecoder/backbone/resnet.py @@ -0,0 +1,731 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import pickle +import numpy as np +from typing import Any, Dict +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn.functional as F +from torch import nn + + +from .backbone import Backbone +from .registry import register_backbone + +from detectron2.layers import ( + CNNBlockBase, + Conv2d, + DeformConv, + ModulatedDeformConv, + ShapeSpec, + get_norm, +) +from detectron2.utils.file_io import PathManager + +__all__ = [ + "ResNetBlockBase", + "BasicBlock", + "BottleneckBlock", + "DeformBottleneckBlock", + "BasicStem", + "ResNet", + "make_stage", + "get_resnet_backbone", +] + + +class BasicBlock(CNNBlockBase): + """ + The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`, + with two 3x3 conv layers and a projection shortcut if needed. + """ + + def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int): Stride for the first conv. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. 
+ """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + self.conv2 = Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + out = self.conv2(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class BottleneckBlock(CNNBlockBase): + """ + The standard bottleneck residual block used by ResNet-50, 101 and 152 + defined in :paper:`ResNet`. It contains 3 conv layers with kernels + 1x1, 3x3, 1x1, and a projection shortcut if needed. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + dilation=1, + ): + """ + Args: + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + num_groups (int): number of groups for the 3x3 conv layer. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + stride_in_1x1 (bool): when stride>1, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + dilation (int): the dilation rate of the 3x3 conv layer. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. + # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "For BN layers, the learnable scaling coefficient γ is initialized + # to be 1, except for each residual block's last BN + # where γ is initialized to be 0." 
+ + # nn.init.constant_(self.conv3.norm.weight, 0) + # TODO this somehow hurts performance when training GN models from scratch. + # Add it as an option when we need to use this code to train a backbone. + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + out = self.conv2(out) + out = F.relu_(out) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class DeformBottleneckBlock(CNNBlockBase): + """ + Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv ` + in the 3x3 convolution. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + dilation=1, + deform_modulated=False, + deform_num_groups=1, + ): + super().__init__(in_channels, out_channels, stride) + self.deform_modulated = deform_modulated + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + + if deform_modulated: + deform_conv_op = ModulatedDeformConv + # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size + offset_channels = 27 + else: + deform_conv_op = DeformConv + offset_channels = 18 + + self.conv2_offset = Conv2d( + bottleneck_channels, + offset_channels * deform_num_groups, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + dilation=dilation, + ) + self.conv2 = deform_conv_op( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + deformable_groups=deform_num_groups, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + nn.init.constant_(self.conv2_offset.weight, 0) + nn.init.constant_(self.conv2_offset.bias, 0) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + if self.deform_modulated: + offset_mask = self.conv2_offset(out) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + out = self.conv2(out, offset, mask) + else: + offset = self.conv2_offset(out) + out = self.conv2(out, offset) + out = F.relu_(out) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class BasicStem(CNNBlockBase): + """ + The standard ResNet stem (layers before the first residual block), + with a conv, relu and max_pool. + """ + + def __init__(self, in_channels=3, out_channels=64, norm="BN"): + """ + Args: + norm (str or callable): norm after the first conv layer. + See :func:`layers.get_norm` for supported format. 
+ """ + super().__init__(in_channels, out_channels, 4) + self.in_channels = in_channels + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False, + norm=get_norm(norm, out_channels), + ) + weight_init.c2_msra_fill(self.conv1) + + def forward(self, x): + x = self.conv1(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + +class ResNet(Backbone): + """ + Implement :paper:`ResNet`. + """ + + def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0): + """ + Args: + stem (nn.Module): a stem module + stages (list[list[CNNBlockBase]]): several (typically 4) stages, + each contains multiple :class:`CNNBlockBase`. + num_classes (None or int): if None, will not perform classification. + Otherwise, will create a linear layer. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + """ + super().__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stage_names, self.stages = [], [] + + if out_features is not None: + # Avoid keeping unused layers in this module. They consume extra memory + # and may cause allreduce to fail + num_stages = max( + [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features] + ) + stages = stages[:num_stages] + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = "res" + str(i + 2) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stage_names.append(name) + self.stages.append(stage) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels + self.stage_names = tuple(self.stage_names) # Make it static for scripting + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "The 1000-way fully-connected layer is initialized by + # drawing weights from a zero-mean Gaussian with standard deviation of 0.01." + nn.init.normal_(self.linear.weight, std=0.01) + name = "linear" + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format(", ".join(children)) + self.freeze(freeze_at) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" 
+ outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for name, stage in zip(self.stage_names, self.stages): + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + if "linear" in self._out_features: + outputs["linear"] = x + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + def freeze(self, freeze_at=0): + """ + Freeze the first several stages of the ResNet. Commonly used in + fine-tuning. + + Layers that produce the same feature map spatial size are defined as one + "stage" by :paper:`FPN`. + + Args: + freeze_at (int): number of stages to freeze. + `1` means freezing the stem. `2` means freezing the stem and + one residual stage, etc. + + Returns: + nn.Module: this ResNet itself + """ + if freeze_at >= 1: + self.stem.freeze() + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + for block in stage.children(): + block.freeze() + return self + + @staticmethod + def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs): + """ + Create a list of blocks of the same type that forms one ResNet stage. + + Args: + block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this + stage. A module of this type must not change spatial resolution of inputs unless its + stride != 1. + num_blocks (int): number of blocks in this stage + in_channels (int): input channels of the entire stage. + out_channels (int): output channels of **every block** in the stage. + kwargs: other arguments passed to the constructor of + `block_class`. If the argument name is "xx_per_block", the + argument is a list of values to be passed to each block in the + stage. Otherwise, the same argument is passed to every block + in the stage. + + Returns: + list[CNNBlockBase]: a list of block module. + + Examples: + :: + stage = ResNet.make_stage( + BottleneckBlock, 3, in_channels=16, out_channels=64, + bottleneck_channels=16, num_groups=1, + stride_per_block=[2, 1, 1], + dilations_per_block=[1, 1, 2] + ) + + Usually, layers that produce the same feature map spatial size are defined as one + "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should + all be 1. + """ + blocks = [] + for i in range(num_blocks): + curr_kwargs = {} + for k, v in kwargs.items(): + if k.endswith("_per_block"): + assert len(v) == num_blocks, ( + f"Argument '{k}' of make_stage should have the " + f"same length as num_blocks={num_blocks}." + ) + newk = k[: -len("_per_block")] + assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!" + curr_kwargs[newk] = v[i] + else: + curr_kwargs[k] = v + + blocks.append( + block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs) + ) + in_channels = out_channels + return blocks + + @staticmethod + def make_default_stages(depth, block_class=None, **kwargs): + """ + Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152). + If it doesn't create the ResNet variant you need, please use :meth:`make_stage` + instead for fine-grained customization. + + Args: + depth (int): depth of ResNet + block_class (type): the CNN block class. Has to accept + `bottleneck_channels` argument for depth > 50. + By default it is BasicBlock or BottleneckBlock, based on the + depth. 
+        kwargs:
+            other arguments to pass to `make_stage`. Should not contain
+            stride and channels, as they are predefined for each depth.
+
+        Returns:
+            list[list[CNNBlockBase]]: modules in all stages; see arguments of
+                :class:`ResNet.__init__`.
+        """
+        num_blocks_per_stage = {
+            18: [2, 2, 2, 2],
+            34: [3, 4, 6, 3],
+            50: [3, 4, 6, 3],
+            101: [3, 4, 23, 3],
+            152: [3, 8, 36, 3],
+        }[depth]
+        if block_class is None:
+            block_class = BasicBlock if depth < 50 else BottleneckBlock
+        if depth < 50:
+            in_channels = [64, 64, 128, 256]
+            out_channels = [64, 128, 256, 512]
+        else:
+            in_channels = [64, 256, 512, 1024]
+            out_channels = [256, 512, 1024, 2048]
+        ret = []
+        for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
+            if depth >= 50:
+                kwargs["bottleneck_channels"] = o // 4
+            ret.append(
+                ResNet.make_stage(
+                    block_class=block_class,
+                    num_blocks=n,
+                    stride_per_block=[s] + [1] * (n - 1),
+                    in_channels=i,
+                    out_channels=o,
+                    **kwargs,
+                )
+            )
+        return ret
+
+
+ResNetBlockBase = CNNBlockBase
+"""
+Alias for backward compatibility.
+"""
+
+
+def make_stage(*args, **kwargs):
+    """
+    Deprecated alias for backward compatibility.
+    """
+    return ResNet.make_stage(*args, **kwargs)
+
+
+def _convert_ndarray_to_tensor(state_dict: Dict[str, Any]) -> None:
+    """
+    In-place convert all numpy arrays in the state_dict to torch tensors.
+    Args:
+        state_dict (dict): a state-dict to be loaded to the model.
+            Will be modified.
+    """
+    # model could be an OrderedDict with _metadata attribute
+    # (as returned by Pytorch's state_dict()). We should preserve these
+    # properties.
+    for k in list(state_dict.keys()):
+        v = state_dict[k]
+        if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
+            raise ValueError(
+                "Unsupported type found in checkpoint! {}: {}".format(k, type(v))
+            )
+        if not isinstance(v, torch.Tensor):
+            state_dict[k] = torch.from_numpy(v)
+
+
+@register_backbone
+def get_resnet_backbone(cfg):
+    """
+    Create a ResNet instance from config.
+
+    Returns:
+        ResNet: a :class:`ResNet` instance.
+    """
+    res_cfg = cfg['MODEL']['BACKBONE']['RESNETS']
+
+    # need registration of new blocks/stems?
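For orientation, get_resnet_backbone consumes a plain nested dict. The sketch below shows the layout it expects; only the key names come from the lookups that follow, while the concrete values are illustrative assumptions for a vanilla ResNet-50:

    # Illustrative config for get_resnet_backbone (values are assumptions).
    example_cfg = {
        'MODEL': {
            'BACKBONE': {
                'LOAD_PRETRAINED': False,
                'PRETRAINED': '',  # path to a detectron2-style .pkl checkpoint when LOAD_PRETRAINED is True
                'RESNETS': {
                    'NORM': 'FrozenBN',
                    'STEM_IN_CHANNELS': 3,
                    'STEM_OUT_CHANNELS': 64,
                    'FREEZE_AT': 0,
                    'OUT_FEATURES': ['res2', 'res3', 'res4', 'res5'],
                    'DEPTH': 50,
                    'NUM_GROUPS': 1,
                    'WIDTH_PER_GROUP': 64,
                    'RES2_OUT_CHANNELS': 256,
                    'STRIDE_IN_1X1': False,
                    'RES5_DILATION': 1,
                    'DEFORM_ON_PER_STAGE': [False, False, False, False],
                    'DEFORM_MODULATED': False,
                    'DEFORM_NUM_GROUPS': 1,
                },
            },
        },
    }
    # backbone = get_resnet_backbone(example_cfg)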
+ norm = res_cfg['NORM'] + stem = BasicStem( + in_channels=res_cfg['STEM_IN_CHANNELS'], + out_channels=res_cfg['STEM_OUT_CHANNELS'], + norm=norm, + ) + + # fmt: off + freeze_at = res_cfg['FREEZE_AT'] + out_features = res_cfg['OUT_FEATURES'] + depth = res_cfg['DEPTH'] + num_groups = res_cfg['NUM_GROUPS'] + width_per_group = res_cfg['WIDTH_PER_GROUP'] + bottleneck_channels = num_groups * width_per_group + in_channels = res_cfg['STEM_OUT_CHANNELS'] + out_channels = res_cfg['RES2_OUT_CHANNELS'] + stride_in_1x1 = res_cfg['STRIDE_IN_1X1'] + res5_dilation = res_cfg['RES5_DILATION'] + deform_on_per_stage = res_cfg['DEFORM_ON_PER_STAGE'] + deform_modulated = res_cfg['DEFORM_MODULATED'] + deform_num_groups = res_cfg['DEFORM_NUM_GROUPS'] + # fmt: on + assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) + + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + + if depth in [18, 34]: + assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34" + assert not any( + deform_on_per_stage + ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34" + assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34" + assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34" + + stages = [] + + for idx, stage_idx in enumerate(range(2, 6)): + # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper + dilation = res5_dilation if stage_idx == 5 else 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1), + "in_channels": in_channels, + "out_channels": out_channels, + "norm": norm, + } + # Use BasicBlock for R18 and R34. + if depth in [18, 34]: + stage_kargs["block_class"] = BasicBlock + else: + stage_kargs["bottleneck_channels"] = bottleneck_channels + stage_kargs["stride_in_1x1"] = stride_in_1x1 + stage_kargs["dilation"] = dilation + stage_kargs["num_groups"] = num_groups + if deform_on_per_stage[idx]: + stage_kargs["block_class"] = DeformBottleneckBlock + stage_kargs["deform_modulated"] = deform_modulated + stage_kargs["deform_num_groups"] = deform_num_groups + else: + stage_kargs["block_class"] = BottleneckBlock + blocks = ResNet.make_stage(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + stages.append(blocks) + backbone = ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at) + + if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: + filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] + with PathManager.open(filename, "rb") as f: + ckpt = pickle.load(f, encoding="latin1")['model'] + _convert_ndarray_to_tensor(ckpt) + ckpt.pop('stem.fc.weight') + ckpt.pop('stem.fc.bias') + backbone.load_state_dict(ckpt) + + return backbone diff --git a/xdecoder/backbone/swin.py b/xdecoder/backbone/swin.py new file mode 100755 index 0000000000000000000000000000000000000000..ed66e670a10762d7faf1e16bb2d6d80691182aca --- /dev/null +++ b/xdecoder/backbone/swin.py @@ -0,0 +1,892 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu, Yutong Lin, Yixuan Wei +# -------------------------------------------------------- + +# Copyright (c) Facebook, Inc. and its affiliates. 
+# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py +import logging +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from detectron2.modeling import Backbone, ShapeSpec +from detectron2.utils.file_io import PathManager + +from .registry import register_backbone + +logger = logging.getLogger(__name__) + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. 
Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + # HACK model will not upsampling + # if min([H, W]) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + # self.shift_size = 0 + # self.window_size = min([H,W]) + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchMerging(nn.Module): + """Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. 
+ Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. 
+ """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ).type(x.dtype) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(nn.Module): + """Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. 
+ attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_2tuple(pretrain_img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1], + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. 
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=0.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+
+    def load_weights(self, pretrained_dict=None, pretrained_layers=[], verbose=True):
+        model_dict = self.state_dict()
+        pretrained_dict = {
+            k: v for k, v in pretrained_dict.items()
+            if k in model_dict.keys()
+        }
+        need_init_state_dict = {}
+        for k, v in pretrained_dict.items():
+            need_init = (
+                (
+                    k.split('.')[0] in pretrained_layers
+                    or pretrained_layers[0] == '*'
+                )
+                and 'relative_position_index' not in k
+                and 'attn_mask' not in k
+            )
+
+            if need_init:
+                # if verbose:
+                #     logger.info(f'=> init {k} from {pretrained}')
+
+                if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
+                    relative_position_bias_table_pretrained = v
+                    relative_position_bias_table_current = model_dict[k]
+                    L1, nH1 = relative_position_bias_table_pretrained.size()
+                    L2, nH2 = relative_position_bias_table_current.size()
+                    if nH1 != nH2:
+                        logger.info(f"Error in loading {k}, passing")
+                    else:
+                        if L1 != L2:
+                            logger.info(
+                                '=> load_pretrained: resized variant: {} to {}'
+                                .format((L1, nH1), (L2, nH2))
+                            )
+                            S1 = int(L1 ** 0.5)
+                            S2 = int(L2 ** 0.5)
+                            relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
+                                relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
+                                size=(S2, S2),
+                                mode='bicubic')
+                            v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
+
+                if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
+                    absolute_pos_embed_pretrained = v
+                    absolute_pos_embed_current = model_dict[k]
+                    _, L1, C1 = absolute_pos_embed_pretrained.size()
+                    _, L2, C2 = absolute_pos_embed_current.size()
+                    if C1 != C2:
+                        logger.info(f"Error in loading {k}, passing")
+                    else:
+                        if L1 != L2:
+                            logger.info(
+                                '=> load_pretrained: resized variant: {} to {}'
+                                .format((1, L1, C1), (1, L2, C2))
+                            )
+                            S1 = int(L1 ** 0.5)
+                            S2 = int(L2 ** 0.5)
+                            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
+                            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
+                            absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
+                                absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
+                            v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)
+
+                need_init_state_dict[k] = v
+        self.load_state_dict(need_init_state_dict, strict=False)
+
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(
+                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
+            )
+            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
+        else:
+            x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+
+        outs = {}
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+            if i in self.out_indices:
+                norm_layer = getattr(self, f"norm{i}")
+                x_out = norm_layer(x_out)
+
+                out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+                outs["res{}".format(i + 2)] = out
+
+        if len(self.out_indices) == 0:
+            outs["res5"] = x_out.view(-1, H, W,
self.num_features[i]).permute(0, 3, 1, 2).contiguous() + + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +class D2SwinTransformer(SwinTransformer, Backbone): + def __init__(self, cfg, pretrain_img_size, patch_size, in_chans, embed_dim, + depths, num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, + drop_rate, attn_drop_rate, drop_path_rate, norm_layer, ape, + patch_norm, out_indices, use_checkpoint): + super().__init__( + pretrain_img_size, + patch_size, + in_chans, + embed_dim, + depths, + num_heads, + window_size, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_layer, + ape, + patch_norm, + out_indices, + use_checkpoint=use_checkpoint, + ) + + self._out_features = cfg['OUT_FEATURES'] + + self._out_feature_strides = { + "res2": 4, + "res3": 8, + "res4": 16, + "res5": 32, + } + self._out_feature_channels = { + "res2": self.num_features[0], + "res3": self.num_features[1], + "res4": self.num_features[2], + "res5": self.num_features[3], + } + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert ( + x.dim() == 4 + ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + y = super().forward(x) + for k in y.keys(): + if k in self._out_features: + outputs[k] = y[k] + return outputs + + def output_shape(self): + feature_names = list(set(self._out_feature_strides.keys()) & set(self._out_features)) + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in feature_names + } + + @property + def size_divisibility(self): + return 32 + + +@register_backbone +def get_swin_backbone(cfg): + swin_cfg = cfg['MODEL']['BACKBONE']['SWIN'] + + pretrain_img_size = swin_cfg['PRETRAIN_IMG_SIZE'] + patch_size = swin_cfg['PATCH_SIZE'] + in_chans = 3 + embed_dim = swin_cfg['EMBED_DIM'] + depths = swin_cfg['DEPTHS'] + num_heads = swin_cfg['NUM_HEADS'] + window_size = swin_cfg['WINDOW_SIZE'] + mlp_ratio = swin_cfg['MLP_RATIO'] + qkv_bias = swin_cfg['QKV_BIAS'] + qk_scale = swin_cfg['QK_SCALE'] + drop_rate = swin_cfg['DROP_RATE'] + attn_drop_rate = swin_cfg['ATTN_DROP_RATE'] + drop_path_rate = swin_cfg['DROP_PATH_RATE'] + norm_layer = nn.LayerNorm + ape = swin_cfg['APE'] + patch_norm = swin_cfg['PATCH_NORM'] + use_checkpoint = swin_cfg['USE_CHECKPOINT'] + out_indices = swin_cfg.get('OUT_INDICES', [0,1,2,3]) + + swin = D2SwinTransformer( + swin_cfg, + pretrain_img_size, + patch_size, + in_chans, + embed_dim, + depths, + num_heads, + window_size, + mlp_ratio, + qkv_bias, + qk_scale, + drop_rate, + attn_drop_rate, + drop_path_rate, + norm_layer, + ape, + patch_norm, + out_indices, + use_checkpoint=use_checkpoint, + ) + + if cfg['MODEL']['BACKBONE']['LOAD_PRETRAINED'] is True: + filename = cfg['MODEL']['BACKBONE']['PRETRAINED'] + with PathManager.open(filename, "rb") as f: + ckpt = torch.load(f, map_location=cfg['device'])['model'] + swin.load_weights(ckpt, swin_cfg.get('PRETRAINED_LAYERS', ['*']), cfg['VERBOSE']) + + return swin \ No newline at end of file diff --git a/xdecoder/body/__init__.py b/xdecoder/body/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..5b5e32900735a900cc4daef04bb5038cf9f178c9 --- /dev/null +++ 
b/xdecoder/body/__init__.py @@ -0,0 +1 @@ +from .build import build_xdecoder_head \ No newline at end of file diff --git a/xdecoder/body/__pycache__/__init__.cpython-38.pyc b/xdecoder/body/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35cff7617c3dd01830c345f3ab9a30ceedd163ec Binary files /dev/null and b/xdecoder/body/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/body/__pycache__/build.cpython-38.pyc b/xdecoder/body/__pycache__/build.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..429a977683b969e4f9d12b740008284a99c80745 Binary files /dev/null and b/xdecoder/body/__pycache__/build.cpython-38.pyc differ diff --git a/xdecoder/body/__pycache__/registry.cpython-38.pyc b/xdecoder/body/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb72baf7a11d3db535d70f058d575405210b5549 Binary files /dev/null and b/xdecoder/body/__pycache__/registry.cpython-38.pyc differ diff --git a/xdecoder/body/__pycache__/transformer_blocks.cpython-38.pyc b/xdecoder/body/__pycache__/transformer_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b0757911c4be23ec8c407ba123ba293c3181fff Binary files /dev/null and b/xdecoder/body/__pycache__/transformer_blocks.cpython-38.pyc differ diff --git a/xdecoder/body/__pycache__/xdecoder_head.cpython-38.pyc b/xdecoder/body/__pycache__/xdecoder_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11d0996ec01c3eeaae2603f0dcbf55515741b69a Binary files /dev/null and b/xdecoder/body/__pycache__/xdecoder_head.cpython-38.pyc differ diff --git a/xdecoder/body/build.py b/xdecoder/body/build.py new file mode 100755 index 0000000000000000000000000000000000000000..fb35e4cc266c64418f4b21e9d95c7844417a2a56 --- /dev/null +++ b/xdecoder/body/build.py @@ -0,0 +1,13 @@ +from .registry import model_entrypoints +from .registry import is_model + +from .xdecoder_head import * + + +def build_xdecoder_head(config, *args, **kwargs): + model_name = config['MODEL']['HEAD'] + if not is_model(model_name): + raise ValueError(f'Unkown model: {model_name}') + + body = model_entrypoints(model_name)(config, *args, **kwargs) + return body \ No newline at end of file diff --git a/xdecoder/body/decoder/__init__.py b/xdecoder/body/decoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..bbce50aad955329e5cba93e1d4d2f25e3cf694c7 --- /dev/null +++ b/xdecoder/body/decoder/__init__.py @@ -0,0 +1 @@ +from .build import build_decoder \ No newline at end of file diff --git a/xdecoder/body/decoder/__pycache__/__init__.cpython-38.pyc b/xdecoder/body/decoder/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17e2763dc8636feb56c6ab25c6653ea7b0537219 Binary files /dev/null and b/xdecoder/body/decoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/body/decoder/__pycache__/build.cpython-38.pyc b/xdecoder/body/decoder/__pycache__/build.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae14b364d608b8e3ec8c10ef853f12724c11aca9 Binary files /dev/null and b/xdecoder/body/decoder/__pycache__/build.cpython-38.pyc differ diff --git a/xdecoder/body/decoder/__pycache__/registry.cpython-38.pyc b/xdecoder/body/decoder/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b3d678f6fe8e24ed217f9a9591bac543e9d59ad Binary 
files /dev/null and b/xdecoder/body/decoder/__pycache__/registry.cpython-38.pyc differ diff --git a/xdecoder/body/decoder/__pycache__/xdecoder.cpython-38.pyc b/xdecoder/body/decoder/__pycache__/xdecoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dff5c512d382b351ca8e967f24d06d6383a791f Binary files /dev/null and b/xdecoder/body/decoder/__pycache__/xdecoder.cpython-38.pyc differ diff --git a/xdecoder/body/decoder/build.py b/xdecoder/body/decoder/build.py new file mode 100755 index 0000000000000000000000000000000000000000..c5c9be6f177885315a53845a624175430fa48ff1 --- /dev/null +++ b/xdecoder/body/decoder/build.py @@ -0,0 +1,12 @@ +from .registry import model_entrypoints +from .registry import is_model + +from .xdecoder import * + +def build_decoder(config, *args, **kwargs): + model_name = config['MODEL']['DECODER']['NAME'] + + if not is_model(model_name): + raise ValueError(f'Unkown model: {model_name}') + + return model_entrypoints(model_name)(config, *args, **kwargs) \ No newline at end of file diff --git a/xdecoder/body/decoder/registry.py b/xdecoder/body/decoder/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..bd9a7453d5bace3cdd892226f2f40c1a0be1fdb6 --- /dev/null +++ b/xdecoder/body/decoder/registry.py @@ -0,0 +1,13 @@ +_model_entrypoints = {} + +def register_decoder(fn): + module_name_split = fn.__module__.split('.') + model_name = module_name_split[-1] + _model_entrypoints[model_name] = fn + return fn + +def model_entrypoints(model_name): + return _model_entrypoints[model_name] + +def is_model(model_name): + return model_name in _model_entrypoints \ No newline at end of file diff --git a/xdecoder/body/decoder/tmp.py b/xdecoder/body/decoder/tmp.py new file mode 100644 index 0000000000000000000000000000000000000000..d449b4e8fb6ad90b58f6aad20c410450572f647c --- /dev/null +++ b/xdecoder/body/decoder/tmp.py @@ -0,0 +1,664 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py +import logging +from typing import Optional + +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +from timm.models.layers import trunc_normal_ +from detectron2.layers import Conv2d +import fvcore.nn.weight_init as weight_init + +from .registry import register_decoder +from ...utils import configurable +from ...modules import PositionEmbeddingSine + +from image2html.visualizer import VL + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt, avg_attn + + def forward_pre(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + 
pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout(tgt2) + + return tgt, avg_attn + + def forward(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class MultiScaleMaskedTransformerDecoder(nn.Module): + + _version = 2 + + @configurable + def __init__( + self, + lang_encoder: nn.Module, + in_channels, + mask_classification=True, + *, + hidden_dim: int, + dim_proj: int, + num_queries: int, + contxt_len: int, + nheads: int, + dim_feedforward: int, + dec_layers: int, + pre_norm: bool, + mask_dim: int, + task_switch: dict, + captioning_step: int, + enforce_input_project: bool, + ): + """ + NOTE: this interface is experimental. 
+ Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv even if input + channels and hidden dim is identical + """ + super().__init__() + assert mask_classification, "Only support mask classification model" + self.mask_classification = mask_classification + + # positional encoding + N_steps = hidden_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + # define Transformer decoder here + self.num_heads = nheads + self.num_layers = dec_layers + self.contxt_len = contxt_len + self.transformer_self_attention_layers = nn.ModuleList() + self.transformer_cross_attention_layers = nn.ModuleList() + self.transformer_ffn_layers = nn.ModuleList() + + for _ in range(self.num_layers): + self.transformer_self_attention_layers.append( + SelfAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.transformer_cross_attention_layers.append( + CrossAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.transformer_ffn_layers.append( + FFNLayer( + d_model=hidden_dim, + dim_feedforward=dim_feedforward, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + + self.num_queries = num_queries + # learnable query features + self.query_feat = nn.Embedding(num_queries, hidden_dim) + # learnable query p.e. + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + # level embedding (we always use 3 scales) + self.num_feature_levels = 3 + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + self.input_proj = nn.ModuleList() + + for _ in range(self.num_feature_levels): + if in_channels != hidden_dim or enforce_input_project: + self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) + weight_init.c2_xavier_fill(self.input_proj[-1]) + else: + self.input_proj.append(nn.Sequential()) + + self.task_switch = task_switch + + # output FFNs + self.lang_encoder = lang_encoder + if self.task_switch['mask']: + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) + trunc_normal_(self.class_embed, std=.02) + + if task_switch['bbox']: + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + + # Caption Project and query + if task_switch['captioning']: + self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) + trunc_normal_(self.caping_embed, std=.02) + self.query_feat_caping = nn.Embedding(contxt_len, hidden_dim) + self.captioning_step = captioning_step + + # register self_attn_mask to avoid information leakage, it includes interaction between object query, class query and caping query + self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool() + self_attn_mask[:, :num_queries, num_queries:] = True # object+class query does not attend with caption query. 
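+        # Layout of self_attn_mask (True = blocked), combining the assignment above with the
+        # three below. Rows and columns are ordered [object queries | class query | caption tokens]:
+        # object and class queries never attend to caption tokens; caption tokens attend
+        # causally to earlier caption tokens (upper triangle blocked) but may attend to all
+        # object/class queries; the object queries and the class query (the last of the
+        # num_queries slots) are blocked from attending to each other.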
+ self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption query only attend with previous token. + self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object query does not attend with class query. + self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # class query does not attend with object query. + self.register_buffer("self_attn_mask", self_attn_mask) + + + @classmethod + def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra): + ret = {} + + ret["lang_encoder"] = lang_encoder + ret["in_channels"] = in_channels + ret["mask_classification"] = mask_classification + + enc_cfg = cfg['MODEL']['ENCODER'] + dec_cfg = cfg['MODEL']['DECODER'] + + ret["hidden_dim"] = dec_cfg['HIDDEN_DIM'] + ret["dim_proj"] = cfg['MODEL']['DIM_PROJ'] + ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES'] + ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] + + # Transformer parameters: + ret["nheads"] = dec_cfg['NHEADS'] + ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] + + # NOTE: because we add learnable query features which requires supervision, + # we add minus 1 to decoder layers to be consistent with our loss + # implementation: that is, number of auxiliary losses is always + # equal to number of decoder layers. With learnable query features, the number of + # auxiliary losses equals number of decoders plus 1. + assert dec_cfg['DEC_LAYERS'] >= 1 + ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1 + ret["pre_norm"] = dec_cfg['PRE_NORM'] + ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ'] + ret["mask_dim"] = enc_cfg['MASK_DIM'] + + ret["task_switch"] = extra['task_switch'] + ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50) + + return ret + + def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}): + if task == 'captioning_infer': + return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + + predictions_class = [] + predictions_mask = [] + predictions_bbox = [] + predictions_caption = [] + predictions_captioning = [] + + self_tgt_mask = None + if self.training and task == 'vlp' and self.task_switch['captioning']: + output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token. + caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output + query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning. 
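+            # At this point output and query_embed are (num_queries + contxt_len, bs, hidden_dim);
+            # the registered self_attn_mask is expanded below to
+            # (bs * num_heads, num_queries + contxt_len, num_queries + contxt_len).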
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
+ elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
+ grounding_tokens = extra['grounding_tokens']
+ _grounding_tokens = grounding_tokens.detach().clone()
+ # initialize with negative attention at the beginning.
+ pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
+ pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
+ pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens can attend to each other
+ self_tgt_mask = pad_tgt_mask
+ output = torch.cat((output, output[:-1]), dim=0)
+ query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad the language embedding so its length matches the padded output
+ else:
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
+
+ # prediction heads on learnable query features
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
+ attn_mask = results["attn_mask"]
+ predictions_class.append(results["outputs_class"])
+ predictions_mask.append(results["outputs_mask"])
+ predictions_bbox.append(results["outputs_bbox"])
+ predictions_caption.append(results["outputs_caption"])
+ predictions_captioning.append(results["outputs_captionting"])
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
+ # attention: cross-attention first
+ output, avg_attn = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=attn_mask,
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ output = torch.cat((output, _grounding_tokens), dim=0)
+ query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
+
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=self_tgt_mask,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed
+ )
+
+ # FFN
+ output = self.transformer_ffn_layers[i](
+ output
+ )
+
+ if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding'] \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ _grounding_tokens = output[-len(_grounding_tokens):]
+ output = output[:-len(_grounding_tokens)]
+ query_embed = query_embed[:-len(_grounding_tokens)]
+
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task)
+ attn_mask = results["attn_mask"]
+ predictions_class.append(results["outputs_class"])
+
predictions_mask.append(results["outputs_mask"]) + predictions_bbox.append(results["outputs_bbox"]) + predictions_caption.append(results["outputs_caption"]) + predictions_captioning.append(results["outputs_captionting"]) + + assert len(predictions_class) == self.num_layers + 1 + if task == 'vlp': + out = {'pred_captionings': predictions_captioning[-1], + 'pred_captions': predictions_caption[-1], + 'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]} + return out + else: + out = { + 'pred_logits': predictions_class[-1], + 'pred_masks': predictions_mask[-1], + 'pred_boxes': predictions_bbox[-1], + 'pred_captions': predictions_caption[-1], + 'aux_outputs': self._set_aux_loss( + predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption + ) + } + return out + + def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}): + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + caping_lang_token = extra['start_token'].repeat(bs, 1) + query_feat_caping = self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1) + + # prepare token embedding for evaluation + token_embs = self.lang_encoder.lang_encoder.token_embedding.weight + # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7) + + for cap_idx in range(0, self.captioning_step): + caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1) + query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning. + output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token. 
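+ # (note) greedy decoding: every captioning step re-runs the full decoder stack below, projects the
+ # caption query outputs onto the token embedding table, and takes the argmax at position cap_idx as
+ # the next token (see the end of this loop body).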
+ + # prediction heads on learnable query features + results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task) + attn_mask = results["attn_mask"] + + for i in range(self.num_layers): + level_index = i % self.num_feature_levels + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1) + self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1) + + # attention: cross-attention first + output, avg_attn = self.transformer_cross_attention_layers[i]( + output, src[level_index], + memory_mask=attn_mask, + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=pos[level_index], query_pos=query_embed + ) + + output = self.transformer_self_attention_layers[i]( + output, tgt_mask=self_tgt_mask, + tgt_key_padding_mask=None, + query_pos=query_embed + ) + + # FFN + output = self.transformer_ffn_layers[i]( + output + ) + + results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task) + attn_mask = results["attn_mask"] + + pred_captions_gen = results['outputs_captionting'] + # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7) + pred_captions_gen = pred_captions_gen @ token_embs.t() + caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1] + + out = {'pred_captionings': caping_lang_token, + 'pred_texts': self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=True)} + return out + + + def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'): + decoder_output = self.decoder_norm(output) + decoder_output = decoder_output.transpose(0, 1) + + # extract image captioning token from decoder output. + if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'): + outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed + else: + outputs_captionting = None + + # recompute class token output. + norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7) + obj_token = norm_decoder_output[:,:self.num_queries-1] + cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries] + + sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token. + cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True) + + if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \ + or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']): + decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1) + else: + decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1) + + # compute class, mask and bbox. 
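+ # (note) the class query output is recomputed below as a similarity-weighted sum of the object query
+ # outputs, with weights given by the softmaxed cosine similarity computed above.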
+ class_embed = decoder_output @ self.class_embed + # HACK do not compute similarity if mask is not on + outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training) or (task == 'openimage'))) + + if self.task_switch['mask'] or self.task_switch['openimage']['mask']: + mask_embed = self.mask_embed(decoder_output) + outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + + # NOTE: prediction is of higher-resolution + # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW] + attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) + + # must use bool type + # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. + attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool() + attn_mask = attn_mask.detach() + + # NOTE: fill False for cls token (JY) + attn_mask[:, self.num_queries:self.num_queries+1].fill_(False) + else: + outputs_mask = None + attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool() + + outputs_bbox = [None for i in range(len(decoder_output))] + if self.task_switch['bbox']: + outputs_bbox = self.bbox_embed(decoder_output) + + outputs_caption = None + if self.task_switch['caption']: + outputs_caption = class_embed + + + results = { + "outputs_class": outputs_class, + "outputs_mask": outputs_mask, + "outputs_bbox": outputs_bbox, + "attn_mask": attn_mask, + "outputs_caption": outputs_caption, + "outputs_captionting": outputs_captionting, + } + return results + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + if self.mask_classification: + return [ + {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d} + for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1]) + ] + else: + return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] + + +@register_decoder +def get_masked_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra): + return MultiScaleMaskedTransformerDecoder(cfg, in_channels, lang_encoder, mask_classification, extra) \ No newline at end of file diff --git a/xdecoder/body/decoder/xdecoder.py b/xdecoder/body/decoder/xdecoder.py new file mode 100755 index 0000000000000000000000000000000000000000..7e0543deaf932963c40bf414f904b8ef82f8fc63 --- /dev/null +++ b/xdecoder/body/decoder/xdecoder.py @@ -0,0 +1,700 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py + +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu), Jianwei Yang (jianwyan@microsoft.com) +# -------------------------------------------------------- + + +import logging +from typing import Optional + +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +from timm.models.layers import trunc_normal_ +from detectron2.layers import Conv2d +import fvcore.nn.weight_init as weight_init + +from .registry import register_decoder +from ...utils import configurable +from ...modules import PositionEmbeddingSine + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + 
key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt, avg_attn + + def forward_pre(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout(tgt2) + + return tgt, avg_attn + + def forward(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class MultiScaleMaskedTransformerDecoder(nn.Module): + + _version = 2 + + @configurable + def __init__( + self, + lang_encoder: nn.Module, + in_channels, + mask_classification=True, + *, + hidden_dim: int, + dim_proj: int, + num_queries: int, + contxt_len: int, + nheads: int, + dim_feedforward: int, + dec_layers: int, + pre_norm: bool, + mask_dim: int, + task_switch: dict, + captioning_step: int, + enforce_input_project: bool, + ): + """ + NOTE: this 
interface is experimental. + Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv even if input + channels and hidden dim is identical + """ + super().__init__() + assert mask_classification, "Only support mask classification model" + self.mask_classification = mask_classification + + # positional encoding + N_steps = hidden_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + # define Transformer decoder here + self.num_heads = nheads + self.num_layers = dec_layers + self.contxt_len = contxt_len + self.transformer_self_attention_layers = nn.ModuleList() + self.transformer_cross_attention_layers = nn.ModuleList() + self.transformer_ffn_layers = nn.ModuleList() + + for _ in range(self.num_layers): + self.transformer_self_attention_layers.append( + SelfAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.transformer_cross_attention_layers.append( + CrossAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.transformer_ffn_layers.append( + FFNLayer( + d_model=hidden_dim, + dim_feedforward=dim_feedforward, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + + self.num_queries = num_queries + # learnable query features + self.query_feat = nn.Embedding(num_queries, hidden_dim) + # learnable query p.e. + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + # level embedding (we always use 3 scales) + self.num_feature_levels = 3 + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + self.input_proj = nn.ModuleList() + + for _ in range(self.num_feature_levels): + if in_channels != hidden_dim or enforce_input_project: + self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) + weight_init.c2_xavier_fill(self.input_proj[-1]) + else: + self.input_proj.append(nn.Sequential()) + + self.task_switch = task_switch + + # output FFNs + self.lang_encoder = lang_encoder + if self.task_switch['mask']: + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) + trunc_normal_(self.class_embed, std=.02) + + if task_switch['bbox']: + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + + # Caption Project and query + if task_switch['captioning']: + self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) + trunc_normal_(self.caping_embed, std=.02) + # self.query_feat_caping = nn.Embedding(contxt_len, hidden_dim) + self.pos_embed_caping = nn.Embedding(contxt_len, hidden_dim) + self.captioning_step = captioning_step + + # register self_attn_mask to avoid information leakage, it includes interaction between object query, class query and caping query + self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool() + self_attn_mask[:, :num_queries, num_queries:] = True # object+class query does not attend with caption query. 
+ self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption query only attend with previous token. + self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object query does not attend with class query. + self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # class query does not attend with object query. + self.register_buffer("self_attn_mask", self_attn_mask) + + + @classmethod + def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra): + ret = {} + + ret["lang_encoder"] = lang_encoder + ret["in_channels"] = in_channels + ret["mask_classification"] = mask_classification + + enc_cfg = cfg['MODEL']['ENCODER'] + dec_cfg = cfg['MODEL']['DECODER'] + + ret["hidden_dim"] = dec_cfg['HIDDEN_DIM'] + ret["dim_proj"] = cfg['MODEL']['DIM_PROJ'] + ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES'] + ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] + + # Transformer parameters: + ret["nheads"] = dec_cfg['NHEADS'] + ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] + + # NOTE: because we add learnable query features which requires supervision, + # we add minus 1 to decoder layers to be consistent with our loss + # implementation: that is, number of auxiliary losses is always + # equal to number of decoder layers. With learnable query features, the number of + # auxiliary losses equals number of decoders plus 1. + assert dec_cfg['DEC_LAYERS'] >= 1 + ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1 + ret["pre_norm"] = dec_cfg['PRE_NORM'] + ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ'] + ret["mask_dim"] = enc_cfg['MASK_DIM'] + + ret["task_switch"] = extra['task_switch'] + ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50) + + return ret + + def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}): + if task == 'captioning_infer': + return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + + predictions_class = [] + predictions_mask = [] + predictions_bbox = [] + predictions_caption = [] + predictions_captioning = [] + + self_tgt_mask = None + if self.training and task == 'vlp' and self.task_switch['captioning']: + # output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token. + caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output + _caping_lang_embed = caping_lang_embed.detach().clone() + output = torch.cat((output, _caping_lang_embed), dim=0) # concat object query, class token and caption token. 
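+ # (variant note) unlike the query_feat_caping decoders in this diff, this file feeds the detached
+ # caption token embeddings directly as decoder inputs and, just below, adds the learned
+ # pos_embed_caping positional embedding before using the language embeddings as query positions.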
+ caping_lang_embed += self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
+ query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
+ elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
+ grounding_tokens = extra['grounding_tokens']
+ _grounding_tokens = grounding_tokens.detach().clone()
+ # initialize with negative attention at the beginning.
+ pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
+ pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
+ pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens can attend to each other
+ self_tgt_mask = pad_tgt_mask
+ output = torch.cat((output, output[:-1]), dim=0)
+ query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad the language embedding so its length matches the padded output
+ else:
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
+
+ # prediction heads on learnable query features
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
+ attn_mask = results["attn_mask"]
+ predictions_class.append(results["outputs_class"])
+ predictions_mask.append(results["outputs_mask"])
+ predictions_bbox.append(results["outputs_bbox"])
+ predictions_caption.append(results["outputs_caption"])
+ predictions_captioning.append(results["outputs_captionting"])
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
+ # attention: cross-attention first
+ output, avg_attn = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=attn_mask,
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ output = torch.cat((output, _grounding_tokens), dim=0)
+ query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
+
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=self_tgt_mask,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed
+ )
+
+ # FFN
+ output = self.transformer_ffn_layers[i](
+ output
+ )
+
+ if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding'] \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ _grounding_tokens = output[-len(_grounding_tokens):]
+ output = output[:-len(_grounding_tokens)]
+ query_embed = query_embed[:-len(_grounding_tokens)]
+
+ results = self.forward_prediction_heads(output, mask_features, 
attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task) + attn_mask = results["attn_mask"] + predictions_class.append(results["outputs_class"]) + predictions_mask.append(results["outputs_mask"]) + predictions_bbox.append(results["outputs_bbox"]) + predictions_caption.append(results["outputs_caption"]) + predictions_captioning.append(results["outputs_captionting"]) + + assert len(predictions_class) == self.num_layers + 1 + if task == 'vlp': + out = {'pred_captionings': predictions_captioning[-1], + 'pred_captions': predictions_caption[-1], + 'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]} + return out + else: + out = { + 'pred_logits': predictions_class[-1], + 'pred_masks': predictions_mask[-1], + 'pred_boxes': predictions_bbox[-1], + 'pred_captions': predictions_caption[-1], + 'aux_outputs': self._set_aux_loss( + predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption + ) + } + return out + + def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}): + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + caping_lang_token = extra['start_token'].repeat(bs, 1) + start_id = 0 + if 'token' in extra: + caping_lang_token[:,:len(extra['token'][0])] = extra['token'] + start_id = len(extra['token'][0])-1 + # query_feat_caping = self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1) + pos_embed_caping = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1) + # prepare token embedding for evaluation + token_embs = self.lang_encoder.lang_encoder.token_embedding.weight + # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7) + + for cap_idx in range(start_id, self.captioning_step): + caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1) + output = torch.cat((query_feat, caping_lang_embed), dim=0) # concat object query, class token and caption token. + caping_lang_embed += pos_embed_caping + query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning. + # output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token. 
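+ # (note) when extra['token'] supplies a prefix, the tokens up to start_id are kept and decoding
+ # resumes from start_id, so the generated caption continues the given prefix.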
+ + # prediction heads on learnable query features + results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task) + attn_mask = results["attn_mask"] + + for i in range(self.num_layers): + level_index = i % self.num_feature_levels + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1) + self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1) + + if extra['captioning_mask'] is not None: + bs,nq,wh = attn_mask.shape + assert bs==self.num_heads, "Only support single image referring captioning." + cap_mask = extra['captioning_mask'] + attn_mask = attn_mask.reshape(bs,nq,size_list[i%3][0],size_list[i%3][1]) + cap_mask = F.interpolate(cap_mask[None,].float(), size_list[i%3], mode='nearest').bool()[0,0] + attn_mask[:,self.num_queries:, cap_mask] = True + attn_mask = attn_mask.reshape(bs,nq,wh) + + # attention: cross-attention first + output, avg_attn = self.transformer_cross_attention_layers[i]( + output, src[level_index], + memory_mask=attn_mask, + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=pos[level_index], query_pos=query_embed + ) + + output = self.transformer_self_attention_layers[i]( + output, tgt_mask=self_tgt_mask, + tgt_key_padding_mask=None, + query_pos=query_embed + ) + + # FFN + output = self.transformer_ffn_layers[i]( + output + ) + + results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task) + attn_mask = results["attn_mask"] + + pred_captions_gen = results['outputs_captionting'] + # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7) + pred_captions_gen = pred_captions_gen @ token_embs.t() + caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1] + + texts = self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=False) + texts_new = [] + + for x in texts: + x = x.split('<|endoftext|>')[0] + x = x.replace('<|endoftext|>','') + x = x.replace('<|startoftext|>','') + x = x.strip() + texts_new.append(x) + + out = {'pred_captionings': caping_lang_token, + 'pred_texts': texts_new} + return out + + + def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'): + decoder_output = self.decoder_norm(output) + decoder_output = decoder_output.transpose(0, 1) + + # extract image captioning token from decoder output. + if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'): + outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed + else: + outputs_captionting = None + + # recompute class token output. + norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7) + obj_token = norm_decoder_output[:,:self.num_queries-1] + cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries] + + sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token. 
+ cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True) + + if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \ + or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']): + decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1) + else: + decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1) + + # compute class, mask and bbox. + class_embed = decoder_output @ self.class_embed + # HACK do not compute similarity if mask is not on + outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training) or (task == 'openimage'))) + + if self.task_switch['mask'] or self.task_switch['openimage']['mask']: + mask_embed = self.mask_embed(decoder_output) + outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + + # NOTE: prediction is of higher-resolution + # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW] + attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) + + # must use bool type + # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. + attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool() + attn_mask = attn_mask.detach() + + # NOTE: fill False for cls token (JY) + attn_mask[:, self.num_queries:self.num_queries+1].fill_(False) + else: + outputs_mask = None + attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool() + + outputs_bbox = [None for i in range(len(decoder_output))] + if self.task_switch['bbox']: + outputs_bbox = self.bbox_embed(decoder_output) + + outputs_caption = None + if self.task_switch['caption']: + outputs_caption = class_embed + + + results = { + "outputs_class": outputs_class, + "outputs_mask": outputs_mask, + "outputs_bbox": outputs_bbox, + "attn_mask": attn_mask, + "outputs_caption": outputs_caption, + "outputs_captionting": outputs_captionting, + } + return results + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + if self.mask_classification: + return [ + {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d} + for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1]) + ] + else: + return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] + + +@register_decoder +def get_masked_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra): + return MultiScaleMaskedTransformerDecoder(cfg, in_channels, lang_encoder, mask_classification, extra) \ No newline at end of file diff --git a/xdecoder/body/decoder/xdecoder2.py b/xdecoder/body/decoder/xdecoder2.py new file mode 100644 index 0000000000000000000000000000000000000000..e99d4623b2e987a66650db71c4a378a0ebaf241a --- /dev/null +++ b/xdecoder/body/decoder/xdecoder2.py @@ -0,0 +1,700 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
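+# NOTE (editorial): this file appears to be a near-duplicate of xdecoder/body/decoder/xdecoder.py; the visible
+# difference is that captioning here uses the learnable query_feat_caping query features, while the
+# pos_embed_caping path kept in xdecoder.py is commented out.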
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py + +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Xueyan Zou (xueyan@cs.wisc.edu), Jianwei Yang (jianwyan@microsoft.com) +# -------------------------------------------------------- + + +import logging +from typing import Optional + +import torch +from torch import nn, Tensor +from torch.nn import functional as F + +from timm.models.layers import trunc_normal_ +from detectron2.layers import Conv2d +import fvcore.nn.weight_init as weight_init + +from .registry import register_decoder +from ...utils import configurable +from ...modules import PositionEmbeddingSine + + +class SelfAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + + return tgt + + def forward_pre(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout(tgt2) + + return tgt + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + return self.forward_post(tgt, tgt_mask, + tgt_key_padding_mask, query_pos) + + +class CrossAttentionLayer(nn.Module): + + def __init__(self, d_model, nhead, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + 
key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt, avg_attn + + def forward_pre(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm(tgt) + tgt2, avg_attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask) + tgt = tgt + self.dropout(tgt2) + + return tgt, avg_attn + + def forward(self, tgt, memory, + memory_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, memory_mask, + memory_key_padding_mask, pos, query_pos) + + +class FFNLayer(nn.Module): + + def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, + activation="relu", normalize_before=False): + super().__init__() + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm = nn.LayerNorm(d_model) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt): + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout(tgt2) + tgt = self.norm(tgt) + return tgt + + def forward_pre(self, tgt): + tgt2 = self.norm(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout(tgt2) + return tgt + + def forward(self, tgt): + if self.normalize_before: + return self.forward_pre(tgt) + return self.forward_post(tgt) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class MultiScaleMaskedTransformerDecoder(nn.Module): + + _version = 2 + + @configurable + def __init__( + self, + lang_encoder: nn.Module, + in_channels, + mask_classification=True, + *, + hidden_dim: int, + dim_proj: int, + num_queries: int, + contxt_len: int, + nheads: int, + dim_feedforward: int, + dec_layers: int, + pre_norm: bool, + mask_dim: int, + task_switch: dict, + captioning_step: int, + enforce_input_project: bool, + ): + """ + NOTE: this 
interface is experimental. + Args: + in_channels: channels of the input features + mask_classification: whether to add mask classifier or not + num_classes: number of classes + hidden_dim: Transformer feature dimension + num_queries: number of queries + nheads: number of heads + dim_feedforward: feature dimension in feedforward network + enc_layers: number of Transformer encoder layers + dec_layers: number of Transformer decoder layers + pre_norm: whether to use pre-LayerNorm or not + mask_dim: mask feature dimension + enforce_input_project: add input project 1x1 conv even if input + channels and hidden dim is identical + """ + super().__init__() + assert mask_classification, "Only support mask classification model" + self.mask_classification = mask_classification + + # positional encoding + N_steps = hidden_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + # define Transformer decoder here + self.num_heads = nheads + self.num_layers = dec_layers + self.contxt_len = contxt_len + self.transformer_self_attention_layers = nn.ModuleList() + self.transformer_cross_attention_layers = nn.ModuleList() + self.transformer_ffn_layers = nn.ModuleList() + + for _ in range(self.num_layers): + self.transformer_self_attention_layers.append( + SelfAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.transformer_cross_attention_layers.append( + CrossAttentionLayer( + d_model=hidden_dim, + nhead=nheads, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.transformer_ffn_layers.append( + FFNLayer( + d_model=hidden_dim, + dim_feedforward=dim_feedforward, + dropout=0.0, + normalize_before=pre_norm, + ) + ) + + self.decoder_norm = nn.LayerNorm(hidden_dim) + + self.num_queries = num_queries + # learnable query features + self.query_feat = nn.Embedding(num_queries, hidden_dim) + # learnable query p.e. + self.query_embed = nn.Embedding(num_queries, hidden_dim) + + # level embedding (we always use 3 scales) + self.num_feature_levels = 3 + self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim) + self.input_proj = nn.ModuleList() + + for _ in range(self.num_feature_levels): + if in_channels != hidden_dim or enforce_input_project: + self.input_proj.append(Conv2d(in_channels, hidden_dim, kernel_size=1)) + weight_init.c2_xavier_fill(self.input_proj[-1]) + else: + self.input_proj.append(nn.Sequential()) + + self.task_switch = task_switch + + # output FFNs + self.lang_encoder = lang_encoder + if self.task_switch['mask']: + self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) + + self.class_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) + trunc_normal_(self.class_embed, std=.02) + + if task_switch['bbox']: + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + + # Caption Project and query + if task_switch['captioning']: + self.caping_embed = nn.Parameter(torch.empty(hidden_dim, dim_proj)) + trunc_normal_(self.caping_embed, std=.02) + self.query_feat_caping = nn.Embedding(contxt_len, hidden_dim) + # self.pos_embed_caping = nn.Embedding(contxt_len, hidden_dim) + self.captioning_step = captioning_step + + # register self_attn_mask to avoid information leakage, it includes interaction between object query, class query and caping query + self_attn_mask = torch.zeros((1, num_queries + contxt_len, num_queries + contxt_len)).bool() + self_attn_mask[:, :num_queries, num_queries:] = True # object+class query does not attend with caption query. 
+ self_attn_mask[:, num_queries:, num_queries:] = torch.triu(torch.ones((1, contxt_len, contxt_len)), diagonal=1).bool() # caption query only attend with previous token. + self_attn_mask[:, :num_queries-1, num_queries-1:num_queries] = True # object query does not attend with class query. + self_attn_mask[:, num_queries-1:num_queries, :num_queries-1] = True # class query does not attend with object query. + self.register_buffer("self_attn_mask", self_attn_mask) + + + @classmethod + def from_config(cls, cfg, in_channels, lang_encoder, mask_classification, extra): + ret = {} + + ret["lang_encoder"] = lang_encoder + ret["in_channels"] = in_channels + ret["mask_classification"] = mask_classification + + enc_cfg = cfg['MODEL']['ENCODER'] + dec_cfg = cfg['MODEL']['DECODER'] + + ret["hidden_dim"] = dec_cfg['HIDDEN_DIM'] + ret["dim_proj"] = cfg['MODEL']['DIM_PROJ'] + ret["num_queries"] = dec_cfg['NUM_OBJECT_QUERIES'] + ret["contxt_len"] = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] + + # Transformer parameters: + ret["nheads"] = dec_cfg['NHEADS'] + ret["dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] + + # NOTE: because we add learnable query features which requires supervision, + # we add minus 1 to decoder layers to be consistent with our loss + # implementation: that is, number of auxiliary losses is always + # equal to number of decoder layers. With learnable query features, the number of + # auxiliary losses equals number of decoders plus 1. + assert dec_cfg['DEC_LAYERS'] >= 1 + ret["dec_layers"] = dec_cfg['DEC_LAYERS'] - 1 + ret["pre_norm"] = dec_cfg['PRE_NORM'] + ret["enforce_input_project"] = dec_cfg['ENFORCE_INPUT_PROJ'] + ret["mask_dim"] = enc_cfg['MASK_DIM'] + + ret["task_switch"] = extra['task_switch'] + ret["captioning_step"] = dec_cfg['CAPTIONING'].get('STEP', 50) + + return ret + + def forward(self, x, mask_features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}): + if task == 'captioning_infer': + return self.forward_captioning(x, mask_features, mask=mask, target_queries=target_queries, target_vlp=target_vlp, task=task, extra=extra) + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + + predictions_class = [] + predictions_mask = [] + predictions_bbox = [] + predictions_caption = [] + predictions_captioning = [] + + self_tgt_mask = None + if self.training and task == 'vlp' and self.task_switch['captioning']: + output = torch.cat((output, self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1)), dim=0) # concat object query, class token and caption token. + caping_lang_embed = torch.cat([caption['caption_tokens'] for caption in target_vlp], dim=0).transpose(0, 1) # language output + # _caping_lang_embed = caping_lang_embed.detach().clone() + # output = torch.cat((output, _caping_lang_embed), dim=0) # concat object query, class token and caption token. 
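+ # (variant note) here the learnable query_feat_caping features (concatenated above) serve as the
+ # caption query inputs; the commented-out lines show the pos_embed_caping alternative used in xdecoder.py.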
+ # caping_lang_embed += self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1)
+ query_embed = torch.cat((query_embed, caping_lang_embed), dim=0) # may not add at the beginning.
+ self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1)
+ elif (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
+ grounding_tokens = extra['grounding_tokens']
+ _grounding_tokens = grounding_tokens.detach().clone()
+ # initialize with negative attention at the beginning.
+ pad_tgt_mask = torch.ones((1, self.num_queries + (self.num_queries-1) + len(grounding_tokens), self.num_queries + (self.num_queries-1) + len(grounding_tokens)), device=self_tgt_mask.device).bool().repeat(output.shape[1]*self.num_heads, 1, 1)
+ pad_tgt_mask[:,:self.num_queries,:self.num_queries] = self_tgt_mask
+ pad_tgt_mask[:,self.num_queries:,self.num_queries:] = False # grounding tokens can attend to each other
+ self_tgt_mask = pad_tgt_mask
+ output = torch.cat((output, output[:-1]), dim=0)
+ query_embed = torch.cat((query_embed, query_embed[:-1]), dim=0) # also pad the language embedding so its length matches the padded output
+ else:
+ self_tgt_mask = self.self_attn_mask[:,:self.num_queries,:self.num_queries].repeat(output.shape[1]*self.num_heads, 1, 1)
+
+ # prediction heads on learnable query features
+ results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task)
+ attn_mask = results["attn_mask"]
+ predictions_class.append(results["outputs_class"])
+ predictions_mask.append(results["outputs_mask"])
+ predictions_bbox.append(results["outputs_bbox"])
+ predictions_caption.append(results["outputs_caption"])
+ predictions_captioning.append(results["outputs_captionting"])
+
+ for i in range(self.num_layers):
+ level_index = i % self.num_feature_levels
+ attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
+
+ if self.training and task == 'vlp' and self.task_switch['captioning']:
+ attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1)
+ # attention: cross-attention first
+ output, avg_attn = self.transformer_cross_attention_layers[i](
+ output, src[level_index],
+ memory_mask=attn_mask,
+ memory_key_padding_mask=None, # here we do not apply masking on padded region
+ pos=pos[level_index], query_pos=query_embed
+ )
+
+ if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ output = torch.cat((output, _grounding_tokens), dim=0)
+ query_embed = torch.cat((query_embed, grounding_tokens), dim=0)
+
+ output = self.transformer_self_attention_layers[i](
+ output, tgt_mask=self_tgt_mask,
+ tgt_key_padding_mask=None,
+ query_pos=query_embed
+ )
+
+ # FFN
+ output = self.transformer_ffn_layers[i](
+ output
+ )
+
+ if ((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding'] \
+ or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']):
+ _grounding_tokens = output[-len(_grounding_tokens):]
+ output = output[:-len(_grounding_tokens)]
+ query_embed = query_embed[:-len(_grounding_tokens)]
+
+ results = self.forward_prediction_heads(output, 
mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task) + attn_mask = results["attn_mask"] + predictions_class.append(results["outputs_class"]) + predictions_mask.append(results["outputs_mask"]) + predictions_bbox.append(results["outputs_bbox"]) + predictions_caption.append(results["outputs_caption"]) + predictions_captioning.append(results["outputs_captionting"]) + + assert len(predictions_class) == self.num_layers + 1 + if task == 'vlp': + out = {'pred_captionings': predictions_captioning[-1], + 'pred_captions': predictions_caption[-1], + 'aux_outputs': [{'pred_captionings': x, 'pred_captions': y } for x, y in zip(predictions_captioning[:-1], predictions_caption[:-1])]} + return out + else: + out = { + 'pred_logits': predictions_class[-1], + 'pred_masks': predictions_mask[-1], + 'pred_boxes': predictions_bbox[-1], + 'pred_captions': predictions_caption[-1], + 'aux_outputs': self._set_aux_loss( + predictions_class if self.mask_classification else None, predictions_mask, predictions_bbox, predictions_caption + ) + } + return out + + def forward_captioning(self, x, mask_features, mask = None, target_queries = None, target_vlp = None, task='seg', extra={}): + # x is a list of multi-scale feature + assert len(x) == self.num_feature_levels + src = [] + pos = [] + size_list = [] + + # disable mask, it does not affect performance + del mask + for i in range(self.num_feature_levels): + size_list.append(x[i].shape[-2:]) + pos.append(self.pe_layer(x[i], None).flatten(2)) + src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None]) + + # flatten NxCxHxW to HWxNxC + pos[-1] = pos[-1].permute(2, 0, 1) + src[-1] = src[-1].permute(2, 0, 1) + + _, bs, _ = src[0].shape + + # QxNxC + query_embed_ = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1) + query_feat = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1) + caping_lang_token = extra['start_token'].repeat(bs, 1) + start_id = 0 + if 'token' in extra: + caping_lang_token[:,:len(extra['token'][0])] = extra['token'] + start_id = len(extra['token'][0])-1 + query_feat_caping = self.query_feat_caping.weight.unsqueeze(1).repeat(1, bs, 1) + # pos_embed_caping = self.pos_embed_caping.weight.unsqueeze(1).repeat(1, bs, 1) + # prepare token embedding for evaluation + token_embs = self.lang_encoder.lang_encoder.token_embedding.weight + # token_embs = (token_embs / token_embs.norm(dim=-1, keepdim=True) + 1e-7) + + for cap_idx in range(start_id, self.captioning_step): + caping_lang_embed = self.lang_encoder.forward_language_token((caping_lang_token,))[0].transpose(0, 1) + # output = torch.cat((query_feat, caping_lang_embed), dim=0) # concat object query, class token and caption token. + # caping_lang_embed += pos_embed_caping + query_embed = torch.cat((query_embed_, caping_lang_embed), dim=0) # may not add at the beginning. + output = torch.cat((query_feat, query_feat_caping), dim=0) # concat object query, class token and caption token. 
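+ # (note) greedy decoding with an optional prefix: tokens before start_id come from extra['token'];
+ # each step re-runs the decoder stack below and takes the argmax of the projected caption logits at
+ # position cap_idx as the next token.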
+ + # prediction heads on learnable query features + results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[0], task=task) + attn_mask = results["attn_mask"] + + for i in range(self.num_layers): + level_index = i % self.num_feature_levels + attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False + attn_mask = torch.cat((attn_mask, torch.zeros_like(attn_mask[:, :self.contxt_len, :])), dim=1) + self_tgt_mask = self.self_attn_mask.repeat(output.shape[1]*self.num_heads, 1, 1) + + if extra['captioning_mask'] is not None: + bs,nq,wh = attn_mask.shape + assert bs==self.num_heads, "Only support single image referring captioning." + cap_mask = extra['captioning_mask'] + attn_mask = attn_mask.reshape(bs,nq,size_list[i%3][0],size_list[i%3][1]) + cap_mask = F.interpolate(cap_mask[None,].float(), size_list[i%3], mode='nearest').bool()[0,0] + attn_mask[:,self.num_queries:, cap_mask] = True + attn_mask = attn_mask.reshape(bs,nq,wh) + + # attention: cross-attention first + output, avg_attn = self.transformer_cross_attention_layers[i]( + output, src[level_index], + memory_mask=attn_mask, + memory_key_padding_mask=None, # here we do not apply masking on padded region + pos=pos[level_index], query_pos=query_embed + ) + + output = self.transformer_self_attention_layers[i]( + output, tgt_mask=self_tgt_mask, + tgt_key_padding_mask=None, + query_pos=query_embed + ) + + # FFN + output = self.transformer_ffn_layers[i]( + output + ) + + results = self.forward_prediction_heads(output, mask_features, attn_mask_target_size=size_list[(i + 1) % self.num_feature_levels], layer_id=i, task=task) + attn_mask = results["attn_mask"] + + pred_captions_gen = results['outputs_captionting'] + # pred_captions_gen = (pred_captions_gen / pred_captions_gen.norm(dim=-1, keepdim=True) + 1e-7) + pred_captions_gen = pred_captions_gen @ token_embs.t() + caping_lang_token[:,cap_idx+1] = pred_captions_gen[:,cap_idx].max(-1)[1] + + texts = self.lang_encoder.tokenizer.batch_decode(caping_lang_token, skip_special_tokens=False) + texts_new = [] + + for x in texts: + x = x.split('<|endoftext|>')[0] + x = x.replace('<|endoftext|>','') + x = x.replace('<|startoftext|>','') + x = x.strip() + texts_new.append(x) + + out = {'pred_captionings': caping_lang_token, + 'pred_texts': texts_new} + return out + + + def forward_prediction_heads(self, output, mask_features, attn_mask_target_size, layer_id=-1, task='seg'): + decoder_output = self.decoder_norm(output) + decoder_output = decoder_output.transpose(0, 1) + + # extract image captioning token from decoder output. + if self.task_switch['captioning'] and (task == 'vlp' or task == 'captioning_infer'): + outputs_captionting = decoder_output[:,self.num_queries:] @ self.caping_embed + else: + outputs_captionting = None + + # recompute class token output. + norm_decoder_output = decoder_output / (decoder_output.norm(dim=-1, keepdim=True) + 1e-7) + obj_token = norm_decoder_output[:,:self.num_queries-1] + cls_token = norm_decoder_output[:,self.num_queries-1:self.num_queries] + + sim = (cls_token @ obj_token.transpose(1,2)).softmax(-1)[:,0,:,None] # TODO include class token. 
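+ # NOTE (editorial comment, not part of the original diff): `sim` is the softmax-normalized similarity of the class token to every object token; the next line re-derives the class token as the similarity-weighted sum of the object-token outputs, i.e. an image-level embedding pooled from the learnable queries.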
+ cls_token = (sim * decoder_output[:,:self.num_queries-1]).sum(dim=1, keepdim=True) + + if (((self.training and task == 'seg') or (task == 'grounding_eval')) and self.task_switch['grounding']) \ + or ((self.training and task == 'openimage') and self.task_switch['openimage']['grounding']): + decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token, decoder_output[:,self.num_queries:2*self.num_queries-1]), dim=1) + else: + decoder_output = torch.cat((decoder_output[:,:self.num_queries-1], cls_token), dim=1) + + # compute class, mask and bbox. + class_embed = decoder_output @ self.class_embed + # HACK do not compute similarity if mask is not on + outputs_class = self.lang_encoder.compute_similarity(class_embed, fake=(((not self.task_switch['mask']) and self.training) or (task == 'openimage'))) + + if self.task_switch['mask'] or self.task_switch['openimage']['mask']: + mask_embed = self.mask_embed(decoder_output) + outputs_mask = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) + + # NOTE: prediction is of higher-resolution + # [B, Q, H, W] -> [B, Q, H*W] -> [B, h, Q, H*W] -> [B*h, Q, HW] + attn_mask = F.interpolate(outputs_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) + + # must use bool type + # If a BoolTensor is provided, positions with ``True`` are not allowed to attend while ``False`` values will be unchanged. + attn_mask = (attn_mask.sigmoid().flatten(2).unsqueeze(1).repeat(1, self.num_heads, 1, 1).flatten(0, 1) < 0.5).bool() + attn_mask = attn_mask.detach() + + # NOTE: fill False for cls token (JY) + attn_mask[:, self.num_queries:self.num_queries+1].fill_(False) + else: + outputs_mask = None + attn_mask = torch.zeros((list(decoder_output.shape[:2]) + [attn_mask_target_size[0]*attn_mask_target_size[1]]), device=decoder_output.device).repeat(self.num_heads, 1, 1).bool() + + outputs_bbox = [None for i in range(len(decoder_output))] + if self.task_switch['bbox']: + outputs_bbox = self.bbox_embed(decoder_output) + + outputs_caption = None + if self.task_switch['caption']: + outputs_caption = class_embed + + + results = { + "outputs_class": outputs_class, + "outputs_mask": outputs_mask, + "outputs_bbox": outputs_bbox, + "attn_mask": attn_mask, + "outputs_caption": outputs_caption, + "outputs_captionting": outputs_captionting, + } + return results + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_seg_masks, outputs_boxes, outputs_captions): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. 
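+ # NOTE (editorial comment, not part of the original diff): one dict is returned per intermediate decoder layer (every prediction except the final one), so auxiliary losses can be attached at each layer.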
+ if self.mask_classification: + return [ + {"pred_logits": a, "pred_masks": b, "pred_boxes": c, "pred_captions": d} + for a, b, c, d in zip(outputs_class[:-1], outputs_seg_masks[:-1], outputs_boxes[:-1], outputs_captions[:-1]) + ] + else: + return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] + + +@register_decoder +def get_masked_transformer_decoder(cfg, in_channels, lang_encoder, mask_classification, extra): + return MultiScaleMaskedTransformerDecoder(cfg, in_channels, lang_encoder, mask_classification, extra) \ No newline at end of file diff --git a/xdecoder/body/encoder/__init__.py b/xdecoder/body/encoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..bf9bb57ca080f4e2f1d1edd7c696285a08faa706 --- /dev/null +++ b/xdecoder/body/encoder/__init__.py @@ -0,0 +1 @@ +from .build import build_encoder \ No newline at end of file diff --git a/xdecoder/body/encoder/__pycache__/__init__.cpython-38.pyc b/xdecoder/body/encoder/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c9acb31c4285cb65ca77953309b6ba3fe6ff6af Binary files /dev/null and b/xdecoder/body/encoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/body/encoder/__pycache__/build.cpython-38.pyc b/xdecoder/body/encoder/__pycache__/build.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbfb049be90b841127341e368e54da86e24e2097 Binary files /dev/null and b/xdecoder/body/encoder/__pycache__/build.cpython-38.pyc differ diff --git a/xdecoder/body/encoder/__pycache__/registry.cpython-38.pyc b/xdecoder/body/encoder/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3923809f09ee13d7d8f0cd4ca2ae79a6b315240 Binary files /dev/null and b/xdecoder/body/encoder/__pycache__/registry.cpython-38.pyc differ diff --git a/xdecoder/body/encoder/__pycache__/transformer_encoder_fpn.cpython-38.pyc b/xdecoder/body/encoder/__pycache__/transformer_encoder_fpn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cb482ca31468e0d7dd5b7f19e936585af6d69c6 Binary files /dev/null and b/xdecoder/body/encoder/__pycache__/transformer_encoder_fpn.cpython-38.pyc differ diff --git a/xdecoder/body/encoder/build.py b/xdecoder/body/encoder/build.py new file mode 100755 index 0000000000000000000000000000000000000000..aabf8bca5c6f54144af3187692afc28de4c9e296 --- /dev/null +++ b/xdecoder/body/encoder/build.py @@ -0,0 +1,12 @@ +from .registry import model_entrypoints +from .registry import is_model + +from .transformer_encoder_fpn import * + +def build_encoder(config, *args, **kwargs): + model_name = config['MODEL']['ENCODER']['NAME'] + + if not is_model(model_name): + raise ValueError(f'Unkown model: {model_name}') + + return model_entrypoints(model_name)(config, *args, **kwargs) \ No newline at end of file diff --git a/xdecoder/body/encoder/registry.py b/xdecoder/body/encoder/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..99426a4495cf65e7ce82193f711aaa225b6d2395 --- /dev/null +++ b/xdecoder/body/encoder/registry.py @@ -0,0 +1,13 @@ +_model_entrypoints = {} + +def register_encoder(fn): + module_name_split = fn.__module__.split('.') + model_name = module_name_split[-1] + _model_entrypoints[model_name] = fn + return fn + +def model_entrypoints(model_name): + return _model_entrypoints[model_name] + +def is_model(model_name): + return model_name in _model_entrypoints diff --git 
a/xdecoder/body/encoder/transformer_encoder_fpn.py b/xdecoder/body/encoder/transformer_encoder_fpn.py new file mode 100755 index 0000000000000000000000000000000000000000..16e449fd3ac19a5d143d4fc61cbafc16158b0654 --- /dev/null +++ b/xdecoder/body/encoder/transformer_encoder_fpn.py @@ -0,0 +1,324 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import numpy as np +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ +from torch.cuda.amp import autocast + +import fvcore.nn.weight_init as weight_init +from detectron2.layers import Conv2d, DeformConv, ShapeSpec, get_norm + +from .registry import register_encoder +from ..transformer_blocks import TransformerEncoder, TransformerEncoderLayer, _get_clones, _get_activation_fn +from ...modules import PositionEmbeddingSine +from ...utils import configurable + +# from ..layers import Conv2d, DeformConv, ShapeSpec, get_norm + +# This is a modified FPN decoder. +class BasePixelDecoder(nn.Module): + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + conv_dim: int, + mask_dim: int, + mask_on: bool, + norm: Optional[Union[str, Callable]] = None, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + conv_dims: number of output channels for the intermediate conv layers. + mask_dim: number of output channels for the final conv layer. + norm (str or callable): normalization for all conv layers + """ + super().__init__() + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + feature_channels = [v.channels for k, v in input_shape] + + lateral_convs = [] + output_convs = [] + + use_bias = norm == "" + for idx, in_channels in enumerate(feature_channels): + if idx == len(self.in_features) - 1: + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + in_channels, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(output_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(None) + output_convs.append(output_conv) + else: + lateral_norm = get_norm(norm, conv_dim) + output_norm = get_norm(norm, conv_dim) + + lateral_conv = Conv2d( + in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm + ) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(lateral_conv) + weight_init.c2_xavier_fill(output_conv) + self.add_module("adapter_{}".format(idx + 1), lateral_conv) + self.add_module("layer_{}".format(idx + 1), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. 
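+ # NOTE (editorial comment, not part of the original diff): the conv lists were built following feature_channels, i.e. from the highest-resolution feature ("res2") to the lowest ("res5"); reversing them here matches the reversed iteration over self.in_features[::-1] in forward_features.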
+ self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + + self.mask_on = mask_on + if self.mask_on: + self.mask_dim = mask_dim + self.mask_features = Conv2d( + conv_dim, + mask_dim, + kernel_size=3, + stride=1, + padding=1, + ) + weight_init.c2_xavier_fill(self.mask_features) + + self.maskformer_num_feature_levels = 3 # always use 3 scales + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + enc_cfg = cfg['MODEL']['ENCODER'] + ret = {} + ret["input_shape"] = { + k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] + } + ret["conv_dim"] = enc_cfg['CONVS_DIM'] + ret["mask_dim"] = enc_cfg['MASK_DIM'] + ret["norm"] = enc_cfg['NORM'] + return ret + + def forward_features(self, features): + multi_scale_features = [] + num_cur_levels = 0 + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + y = output_conv(x) + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") + y = output_conv(y) + if num_cur_levels < self.maskformer_num_feature_levels: + multi_scale_features.append(y) + num_cur_levels += 1 + + mask_features = self.mask_features(y) if self.mask_on else None + return mask_features, None, multi_scale_features + + def forward(self, features, targets=None): + logger = logging.getLogger(__name__) + logger.warning("Calling forward() may cause unpredicted behavior of PixelDecoder module.") + return self.forward_features(features) + + +class TransformerEncoderOnly(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + if mask is not None: + mask = mask.flatten(1) + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + return memory.permute(1, 2, 0).view(bs, c, h, w) + + +# This is a modified FPN decoder with extra Transformer encoder that processes the lowest-resolution feature map. +class TransformerEncoderPixelDecoder(BasePixelDecoder): + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + transformer_dropout: float, + transformer_nheads: int, + transformer_dim_feedforward: int, + transformer_enc_layers: int, + transformer_pre_norm: bool, + conv_dim: int, + mask_dim: int, + mask_on: int, + norm: Optional[Union[str, Callable]] = None, + ): + """ + NOTE: this interface is experimental. 
+ Args: + input_shape: shapes (channels and stride) of the input features + transformer_dropout: dropout probability in transformer + transformer_nheads: number of heads in transformer + transformer_dim_feedforward: dimension of feedforward network + transformer_enc_layers: number of transformer encoder layers + transformer_pre_norm: whether to use pre-layernorm or not + conv_dims: number of output channels for the intermediate conv layers. + mask_dim: number of output channels for the final conv layer. + norm (str or callable): normalization for all conv layers + """ + super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm, mask_on=mask_on) + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + in_channels = feature_channels[len(self.in_features) - 1] + self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1) + weight_init.c2_xavier_fill(self.input_proj) + self.transformer = TransformerEncoderOnly( + d_model=conv_dim, + dropout=transformer_dropout, + nhead=transformer_nheads, + dim_feedforward=transformer_dim_feedforward, + num_encoder_layers=transformer_enc_layers, + normalize_before=transformer_pre_norm, + ) + N_steps = conv_dim // 2 + self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) + + # update layer + use_bias = norm == "" + output_norm = get_norm(norm, conv_dim) + output_conv = Conv2d( + conv_dim, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + activation=F.relu, + ) + weight_init.c2_xavier_fill(output_conv) + delattr(self, "layer_{}".format(len(self.in_features))) + self.add_module("layer_{}".format(len(self.in_features)), output_conv) + self.output_convs[0] = output_conv + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + enc_cfg = cfg['MODEL']['ENCODER'] + dec_cfg = cfg['MODEL']['DECODER'] + + ret = super().from_config(cfg, input_shape) + ret["transformer_dropout"] = dec_cfg['DROPOUT'] + ret["transformer_nheads"] = dec_cfg['NHEADS'] + ret["transformer_dim_feedforward"] = dec_cfg['DIM_FEEDFORWARD'] + ret["transformer_enc_layers"] = enc_cfg['TRANSFORMER_ENC_LAYERS'] # a separate config + ret["transformer_pre_norm"] = dec_cfg['PRE_NORM'] + + ret['mask_on'] = cfg['MODEL']['DECODER']['MASK'] + return ret + + def forward_features(self, features): + multi_scale_features = [] + num_cur_levels = 0 + + # Reverse feature maps into top-down order (from low to high resolution) + for idx, f in enumerate(self.in_features[::-1]): + x = features[f] + lateral_conv = self.lateral_convs[idx] + output_conv = self.output_convs[idx] + if lateral_conv is None: + transformer = self.input_proj(x) + pos = self.pe_layer(x) + transformer = self.transformer(transformer, None, pos) + y = output_conv(transformer) + # save intermediate feature as input to Transformer decoder + transformer_encoder_features = transformer + else: + cur_fpn = lateral_conv(x) + # Following FPN implementation, we use nearest upsampling here + y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") + y = output_conv(y) + if num_cur_levels < self.maskformer_num_feature_levels: + multi_scale_features.append(y) + num_cur_levels += 1 + + mask_features = self.mask_features(y) if self.mask_on else None + return mask_features, transformer_encoder_features, multi_scale_features + + def 
forward(self, features, targets=None): + logger = logging.getLogger(__name__) + logger.warning("Calling forward() may cause unpredictable behavior of PixelDecoder module.") + return self.forward_features(features) + + + +@register_encoder +def get_transformer_encoder_fpn(cfg, input_shape): + """ + Build a pixel decoder from `cfg['MODEL']['ENCODER']['NAME']`. + """ + model = TransformerEncoderPixelDecoder(cfg, input_shape) + forward_features = getattr(model, "forward_features", None) + if not callable(forward_features): + raise ValueError( + "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. " + f"Please implement forward_features for {model.__class__.__name__} to only return mask features." + ) + return model \ No newline at end of file diff --git a/xdecoder/body/registry.py b/xdecoder/body/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..0200b0af6cd9e01451be4df9f713719f45f2e928 --- /dev/null +++ b/xdecoder/body/registry.py @@ -0,0 +1,14 @@ +_model_entrypoints = {} + + +def register_body(fn): + module_name_split = fn.__module__.split('.') + model_name = module_name_split[-1] + _model_entrypoints[model_name] = fn + return fn + +def model_entrypoints(model_name): + return _model_entrypoints[model_name] + +def is_model(model_name): + return model_name in _model_entrypoints \ No newline at end of file diff --git a/xdecoder/body/transformer_blocks.py b/xdecoder/body/transformer_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..54134f34556b32c98401be2eb862e539ccb812d4 --- /dev/null +++ b/xdecoder/body/transformer_blocks.py @@ -0,0 +1,370 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py +""" +Transformer class.
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import List, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Transformer(nn.Module): + def __init__( + self, + d_model=512, + nhead=8, + num_encoder_layers=6, + num_decoder_layers=6, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + return_intermediate_dec=False, + ): + super().__init__() + + encoder_layer = TransformerEncoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer( + d_model, nhead, dim_feedforward, dropout, activation, normalize_before + ) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder( + decoder_layer, + num_decoder_layers, + decoder_norm, + return_intermediate=return_intermediate_dec, + ) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + if mask is not None: + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder( + tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed + ) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + 
intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + + src2 = self.self_attn( + q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + 
tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") diff --git a/xdecoder/body/xdecoder_head.py b/xdecoder/body/xdecoder_head.py new file mode 100755 index 0000000000000000000000000000000000000000..b04af973501c2c361de2b4a3a78ebbab1ae44b8a --- /dev/null +++ b/xdecoder/body/xdecoder_head.py @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+ +# -------------------------------------------------------- +# X-Decoder -- Generalized Decoding for Pixel, Image, and Language +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Jianwei Yang (jianwyan@microsoft.com), Xueyan Zou (xueyan@cs.wisc.edu) +# -------------------------------------------------------- + +from typing import Dict + +from torch import nn + +from detectron2.layers import ShapeSpec + +from .registry import register_body +from .encoder import build_encoder +from .decoder import build_decoder +from ..utils import configurable + + +class XDecoderHead(nn.Module): + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + num_classes: int, + pixel_decoder: nn.Module, + loss_weight: float = 1.0, + ignore_value: int = -1, + # extra parameters + transformer_predictor: nn.Module, + transformer_in_feature: str, + ): + """ + NOTE: this interface is experimental. + Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + pixel_decoder: the pixel decoder module + loss_weight: loss weight + ignore_value: category id to be ignored during training. + transformer_predictor: the transformer decoder that makes prediction + transformer_in_feature: input feature name to the transformer_predictor + """ + super().__init__() + + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + self.in_features = [k for k, v in input_shape] + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + self.ignore_value = ignore_value + self.common_stride = 4 + self.loss_weight = loss_weight + + self.pixel_decoder = pixel_decoder + self.predictor = transformer_predictor + self.transformer_in_feature = transformer_in_feature + + self.num_classes = num_classes + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec], lang_encoder: nn.Module, extra: dict): + + in_features_type = cfg['MODEL']['DECODER']['TRANSFORMER_IN_FEATURE'] + enc_cfg = cfg['MODEL']['ENCODER'] + dec_cfg = cfg['MODEL']['DECODER'] + + # figure out in_channels to transformer predictor + if in_features_type == "transformer_encoder": + transformer_predictor_in_channels = enc_cfg['CONVS_DIM'] + elif in_features_type == "pixel_embedding": + transformer_predictor_in_channels = enc_cfg['MASK_DIM'] + elif in_features_type == "multi_scale_pixel_decoder": # for maskformer2 + transformer_predictor_in_channels = enc_cfg['CONVS_DIM'] + else: + transformer_predictor_in_channels = input_shape[dec_cfg['TRANSFORMER_IN_FEATURE']].channels + + return { + "input_shape": { + k: v for k, v in input_shape.items() if k in enc_cfg['IN_FEATURES'] + }, + "ignore_value": enc_cfg['IGNORE_VALUE'], + "num_classes": enc_cfg.get('NUM_CLASSES', None), + "pixel_decoder": build_encoder(cfg, input_shape), + "loss_weight": enc_cfg['LOSS_WEIGHT'], + "transformer_in_feature": dec_cfg['TRANSFORMER_IN_FEATURE'], + "transformer_predictor": build_decoder( + cfg, + transformer_predictor_in_channels, + lang_encoder, + mask_classification=True, + extra=extra, + ), + } + + def forward(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}): + return self.layers(features, mask, target_queries, target_vlp, task, extra) + + def layers(self, features, mask=None, target_queries=None, target_vlp=None, task='seg', extra={}): + mask_features, transformer_encoder_features, multi_scale_features = 
self.pixel_decoder.forward_features(features) + + if self.transformer_in_feature == "multi_scale_pixel_decoder": + predictions = self.predictor(multi_scale_features, mask_features, mask, target_queries, target_vlp, task, extra) + else: + if self.transformer_in_feature == "transformer_encoder": + assert ( + transformer_encoder_features is not None + ), "Please use the TransformerEncoderPixelDecoder." + predictions = self.predictor(transformer_encoder_features, mask_features, mask) + elif self.transformer_in_feature == "pixel_embedding": + predictions = self.predictor(mask_features, mask_features, mask) + else: + predictions = self.predictor(features[self.transformer_in_feature], mask_features, mask) + return predictions + + +@register_body +def get_xdecoder_head(cfg, input_shape, lang_encoder, extra): + return XDecoderHead(cfg, input_shape, lang_encoder, extra) \ No newline at end of file diff --git a/xdecoder/language/LangEncoder/__init__.py b/xdecoder/language/LangEncoder/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..ebc0a5d2e6bc4a4a93935450838acf09455004f6 --- /dev/null +++ b/xdecoder/language/LangEncoder/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from .build import build_lang_encoder +from .build import build_tokenizer + +from .transformer import * \ No newline at end of file diff --git a/xdecoder/language/LangEncoder/__pycache__/__init__.cpython-38.pyc b/xdecoder/language/LangEncoder/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c45d9e0c0b58321a4899fca74e9af0bbd11d405 Binary files /dev/null and b/xdecoder/language/LangEncoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/language/LangEncoder/__pycache__/build.cpython-38.pyc b/xdecoder/language/LangEncoder/__pycache__/build.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f73626d5b666bb6ce2d846ceebf264071e1b1fdb Binary files /dev/null and b/xdecoder/language/LangEncoder/__pycache__/build.cpython-38.pyc differ diff --git a/xdecoder/language/LangEncoder/__pycache__/registry.cpython-38.pyc b/xdecoder/language/LangEncoder/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8104013a2f476ad66c7a3a5d2e084e6c9663ce37 Binary files /dev/null and b/xdecoder/language/LangEncoder/__pycache__/registry.cpython-38.pyc differ diff --git a/xdecoder/language/LangEncoder/__pycache__/transformer.cpython-38.pyc b/xdecoder/language/LangEncoder/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ba951176c5836dda22a4ec97233b04af7b66432 Binary files /dev/null and b/xdecoder/language/LangEncoder/__pycache__/transformer.cpython-38.pyc differ diff --git a/xdecoder/language/LangEncoder/build.py b/xdecoder/language/LangEncoder/build.py new file mode 100755 index 0000000000000000000000000000000000000000..87a39af5e17ad08f583fc294716491fb87469287 --- /dev/null +++ b/xdecoder/language/LangEncoder/build.py @@ -0,0 +1,36 @@ +import os + +from transformers import CLIPTokenizer, CLIPTokenizerFast +from transformers import AutoTokenizer + +from .registry import lang_encoders +from .registry import is_lang_encoder + + +def build_lang_encoder(config_encoder, tokenizer, verbose, **kwargs): + model_name = config_encoder['NAME'] + + if not is_lang_encoder(model_name): + raise ValueError(f'Unkown model: {model_name}') + + return 
lang_encoders(model_name)(config_encoder, tokenizer, verbose, **kwargs) + + +def build_tokenizer(config_encoder): + tokenizer = None + os.environ['TOKENIZERS_PARALLELISM'] = 'true' + if config_encoder['TOKENIZER'] == 'clip': + pretrained_tokenizer = config_encoder.get( + 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' + ) + tokenizer = CLIPTokenizer.from_pretrained(pretrained_tokenizer) + tokenizer.add_special_tokens({'cls_token': tokenizer.eos_token}) + elif config_encoder['TOKENIZER'] == 'clip-fast': + pretrained_tokenizer = config_encoder.get( + 'PRETRAINED_TOKENIZER', 'openai/clip-vit-base-patch32' + ) + tokenizer = CLIPTokenizerFast.from_pretrained(pretrained_tokenizer, from_slow=True) + else: + tokenizer = AutoTokenizer.from_pretrained(config_encoder['TOKENIZER']) + + return tokenizer diff --git a/xdecoder/language/LangEncoder/registry.py b/xdecoder/language/LangEncoder/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..8991272a6e2294ea86eee338cf61d87e4123f724 --- /dev/null +++ b/xdecoder/language/LangEncoder/registry.py @@ -0,0 +1,18 @@ +_lang_encoders = {} + + +def register_lang_encoder(fn): + module_name_split = fn.__module__.split('.') + model_name = module_name_split[-1] + + _lang_encoders[model_name] = fn + + return fn + + +def lang_encoders(model_name): + return _lang_encoders[model_name] + + +def is_lang_encoder(model_name): + return model_name in _lang_encoders diff --git a/xdecoder/language/LangEncoder/transformer.py b/xdecoder/language/LangEncoder/transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..00123460f0aa93801bdf750af62e3a14753c0366 --- /dev/null +++ b/xdecoder/language/LangEncoder/transformer.py @@ -0,0 +1,222 @@ +from collections import OrderedDict +from typing import Tuple, Union +import logging +import os + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from timm.models.layers import DropPath, trunc_normal_ + +from .registry import register_lang_encoder +from utils.distributed import is_main_process +from utils.model import register_norm_module + +logger = logging.getLogger(__name__) + + +@register_norm_module +class LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(LayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + pdtype = x.dtype + x = x.float() + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x.to(pdtype) + self.bias + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None, + drop_path: float = 0.0): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \ + if self.attn_mask is not None else None + + + return self.attn( + x, x, x, + key_padding_mask=key_padding_mask, + need_weights=False, + attn_mask=self.attn_mask + )[0] + + def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None): + x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask)) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + def __init__(self, + context_length: int, + vocab_size: int, + width: int, + layers: int, + heads: int, + drop_path: float = 0.0, + autogressive: bool =True): + super().__init__() + + self.token_embedding = nn.Embedding(vocab_size, width) + + self.context_length = context_length + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, width) + ) + + self.width = width + self.layers = layers + self.autogressive = autogressive + attn_mask = self.build_attention_mask() if autogressive else None + dpr = [x.item() for x in torch.linspace(0, drop_path, layers)] # stochastic depth decay rule + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock(width, heads, attn_mask, dpr[i]) + for i in range(layers) + ] + ) + + self.ln_final = LayerNorm(width) + + trunc_normal_(self.positional_embedding, std=.02) + # nn.init.normal_(self.token_embedding, std=.02) + trunc_normal_(self.token_embedding.weight, std=.02) + self.apply(self._init_weights) + + @property + def dim_out(self): + return self.width + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def _init_weights(self, m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + if is_main_process(): + logger.info('=> init weight of Linear/Conv2d from trunc norm') + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + if is_main_process(): + logger.info('=> init bias of Linear/Conv2d to zeros') + nn.init.constant_(m.bias, 0) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): + nn.init.constant_(m.bias, 0) + + def load_pretrained(self, pretrained='', pretrained_layers=[], verbose=True): + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained, map_location='cpu') + logging.info(f'=> loading pretrained model {pretrained}') + model_dict = self.state_dict() + stripped_key = lambda x: x[13:] if x.startswith('lang_encoder.') else x + pretrained_dict = { + stripped_key(k): v for k, v in pretrained_dict.items() + if stripped_key(k) in model_dict.keys() + } + need_init_state_dict = {} + for k, v in pretrained_dict.items(): + need_init = ( + k.split('.')[0] in pretrained_layers + or pretrained_layers[0] == '*' + ) + if need_init: + if verbose: + logger.info(f'=> init {k} from {pretrained}') + + if 'positional_embedding' in k and v.size() != model_dict[k].size(): + positional_embedding_pretrained = v + positional_embedding_current = model_dict[k] + L1, nH1 = positional_embedding_pretrained.size() + L2, nH2 = positional_embedding_current.size() + if nH1 != nH2: + logger.info(f"Error in loading {k}, passing") + else: + if L1 != L2: + logger.info( + '=> load_pretrained: resized variant: {} to {}' + .format((L1, nH1), (L2, nH2)) + ) + + posemb = 
positional_embedding_pretrained.float() + posemb_grid = posemb.unsqueeze(dim=0).permute(0, 2, 1) + posemb_grid = torch.nn.functional.interpolate(posemb_grid, size=L2, mode='linear') + posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(dim=0) + v = posemb_grid + + need_init_state_dict[k] = v + + self.load_state_dict(need_init_state_dict, strict=False) + + + @torch.jit.ignore + def no_weight_decay(self): + return { + 'positional_embedding', + 'token_embedding', + } + + def forward(self, input_ids, attention_mask=None): + key_padding_mask = (attention_mask == 0) if (not self.autogressive and attention_mask is not None) else None + # key_padding_mask = (input_ids == 0) if not self.autogressive else None + x = self.token_embedding(input_ids) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + for block in self.resblocks: + x = block(x, key_padding_mask) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_final(x) + + return {'last_hidden_state': x} + + +@register_lang_encoder +def lang_encoder(config_encoder, tokenizer, verbose, **kwargs): + transformer = Transformer( + context_length=config_encoder['CONTEXT_LENGTH'], + vocab_size=tokenizer.vocab_size, + width=config_encoder['WIDTH'], + layers=config_encoder['LAYERS'], + heads=config_encoder['HEADS'], + autogressive=config_encoder.get('AUTOGRESSIVE', True) + ) + + if config_encoder.get('LOAD_PRETRAINED', False): + transformer.load_pretrained(config_encoder['PRETRAINED'], config_encoder.get('PRETRAINED_LAYERS', ['*'])) + return transformer diff --git a/xdecoder/language/__init__.py b/xdecoder/language/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..4118dc74282568a13fab564428a19a7b1c30b414 --- /dev/null +++ b/xdecoder/language/__init__.py @@ -0,0 +1,3 @@ +from .fixvlpencoder import * +from .vlpencoder import * +from .build import build_language_encoder \ No newline at end of file diff --git a/xdecoder/language/__pycache__/__init__.cpython-38.pyc b/xdecoder/language/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..912cfb2bd643a32a5eb6b3d0301b17d865161e90 Binary files /dev/null and b/xdecoder/language/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/language/__pycache__/build.cpython-38.pyc b/xdecoder/language/__pycache__/build.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cba74dd1081bd2bbac0945847421c8d261f3d05a Binary files /dev/null and b/xdecoder/language/__pycache__/build.cpython-38.pyc differ diff --git a/xdecoder/language/__pycache__/fixvlpencoder.cpython-38.pyc b/xdecoder/language/__pycache__/fixvlpencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c07b0c1d96bd949258c5fcf890b8e00cd65fb97 Binary files /dev/null and b/xdecoder/language/__pycache__/fixvlpencoder.cpython-38.pyc differ diff --git a/xdecoder/language/__pycache__/loss.cpython-38.pyc b/xdecoder/language/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93d7dac2b007f5e1e16ffcfd3ffc86e6bdc8a0c8 Binary files /dev/null and b/xdecoder/language/__pycache__/loss.cpython-38.pyc differ diff --git a/xdecoder/language/__pycache__/registry.cpython-38.pyc b/xdecoder/language/__pycache__/registry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bb421f17c7fffdb65a4f09ace581979d411c5dc Binary files /dev/null and b/xdecoder/language/__pycache__/registry.cpython-38.pyc differ 
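Editorial note: the language-encoder plumbing added above is registry-driven. `@register_lang_encoder` keys the entry by its defining module name ('transformer'), and `build_tokenizer` / `build_lang_encoder` read everything they need from the MODEL.TEXT config section. The sketch below shows how that wiring is typically exercised; it is illustrative only, and the concrete config values are assumptions rather than values taken from this repository's config files.

# Minimal sketch (not part of the diff): build the CLIP-style text encoder
# registered above. Config keys follow the lang_encoder entrypoint in
# transformer.py; the numeric values here are assumptions for illustration.
from xdecoder.language.LangEncoder import build_tokenizer, build_lang_encoder

config_text = {
    'NAME': 'transformer',   # module name captured by @register_lang_encoder
    'TOKENIZER': 'clip',     # defaults to openai/clip-vit-base-patch32
    'CONTEXT_LENGTH': 77,
    'WIDTH': 512,
    'LAYERS': 12,
    'HEADS': 8,
}
tokenizer = build_tokenizer(config_text)
lang_encoder = build_lang_encoder(config_text, tokenizer, verbose=False)

tokens = tokenizer(["a photo of a zebra"], padding='max_length', truncation=True,
                   max_length=config_text['CONTEXT_LENGTH'], return_tensors='pt')
out = lang_encoder(tokens['input_ids'], tokens['attention_mask'])
# out['last_hidden_state'] has shape [1, CONTEXT_LENGTH, WIDTH]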
diff --git a/xdecoder/language/__pycache__/vlpencoder.cpython-38.pyc b/xdecoder/language/__pycache__/vlpencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fe659df3c65cbe6c040485c53d4cd67bcd55dfa Binary files /dev/null and b/xdecoder/language/__pycache__/vlpencoder.cpython-38.pyc differ diff --git a/xdecoder/language/build.py b/xdecoder/language/build.py new file mode 100755 index 0000000000000000000000000000000000000000..8d9acdf9766e3bc1184c4200ef4dace3437617e4 --- /dev/null +++ b/xdecoder/language/build.py @@ -0,0 +1,11 @@ +from .registry import model_entrypoints +from .registry import is_model + + +def build_language_encoder(config, **kwargs): + model_name = config['MODEL']['TEXT']['ARCH'] + + if not is_model(model_name): + raise ValueError(f'Unkown model: {model_name}') + + return model_entrypoints(model_name)(config, **kwargs) \ No newline at end of file diff --git a/xdecoder/language/fixvlpencoder.py b/xdecoder/language/fixvlpencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..dd91faf136b4e479dba03cc81b21ed5f3b47e1e0 --- /dev/null +++ b/xdecoder/language/fixvlpencoder.py @@ -0,0 +1,35 @@ +from importlib.metadata import requires +import torch +import torch.nn as nn + +from .registry import register_model +from .vlpencoder import LanguageEncoder + +class FixLanguageEncoder(LanguageEncoder): + + def __init__( + self, + *args, **kwargs): + super(FixLanguageEncoder, self).__init__(*args, **kwargs) + self.logit_scale = nn.Parameter(torch.ones([]), requires_grad=False) + + @torch.no_grad() + def get_text_embeddings(self, *args, **kwargs): + return super().get_text_embeddings(*args, **kwargs) + + @torch.no_grad() + def get_text_token_embeddings(self, *args, **kwargs): + return super().get_text_token_embeddings(*args, **kwargs) + + @torch.no_grad() + def forward_language(self, *args, **kwargs): + return super().forward_language(*args, **kwargs) + + @torch.no_grad() + def forward_language_token(self, *args, **kwargs): + return super().forward_language_token(*args, **kwargs) + + +@register_model +def get_language_model(cfg, **kwargs): + return FixLanguageEncoder(cfg) \ No newline at end of file diff --git a/xdecoder/language/loss.py b/xdecoder/language/loss.py new file mode 100755 index 0000000000000000000000000000000000000000..fe7ecd566bbf7f7e5a9981c7789c16c537ecb6b5 --- /dev/null +++ b/xdecoder/language/loss.py @@ -0,0 +1,225 @@ +import pickle +from distutils import log + +import torch +import torch.nn.functional as F +import torch.distributed as dist + +from einops import rearrange, repeat +from timm.loss import SoftTargetCrossEntropy + +soft_cross_entropy = SoftTargetCrossEntropy() + +def is_dist_initialized(): + return torch.distributed.is_initialized() + +def get_world_size(): + if is_dist_initialized(): + return torch.distributed.get_world_size() + return 1 + +def get_rank(): + if is_dist_initialized(): + return dist.get_rank() + return 0 + +def all_gather_grad(x): + if get_world_size() > 1: + all_x = [torch.zeros_like(x) for _ in range(get_world_size())] + torch.distributed.all_gather(all_x, x) + all_x[torch.distributed.get_rank()] = x + x = torch.cat(all_x, dim=0) + return x + +def vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1): + """ + Args: + image_feat (torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256 + text_feat (torch.Tensor): shape [B, L2, C] # B:batch_size, L2: number of selected nouns, C: 256 + + Returns: + """ + # [B, L1, C], L1 = 1 + # image_feat = F.normalize(image_feat, 
dim=-1) + # [B, L2, C] + # text_feat = F.normalize(text_feat, dim=-1) + # HACK: normalize outside + + # [B, L1, L2] + dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l') + # [B, L2, L1] + dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l') + + batch = image_feat.shape[0] + img_len = image_feat.shape[1] + text_len = text_feat.shape[1] + # [B, L1, L2] + pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2') + # [B, L2, L1] + pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1') + + image_x = rearrange(image_feat, 'b l c -> (b l) c') + text_x = rearrange(text_feat, 'b l c -> (b l) c') + + logits_per_img = image_x @ all_gather_grad(text_x).t() + logits_per_text = text_x @ all_gather_grad(image_x).t() + + # get label globally + # [B, L1, B, L2, W] + labels_per_img = F.one_hot( + torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(), + num_classes=get_world_size()).to(image_x.dtype) + labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat( + torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1') + # [BxL1, WxBxL2] + labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)') + # [B, L2, B, L1, W] + labels_per_text = F.one_hot( + torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(), + num_classes=get_world_size()).to(text_x.dtype) + labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat( + torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1') + # [BxL2, WxBxL1] + labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)') + + logit_scale = temperature.exp().clamp(max=100) + + loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img) + loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text) + + loss = 0.5 * (loss_img + loss_text) + return loss + +def vl_contrastive_loss(image_feat, text_feat, temperature=1): + # if image_id or text_id is None, it should be None across all GPUs + # image_feat = F.normalize(image_feat, dim=1) + # text_feat = F.normalize(text_feat, dim=1) + # handle normalization outside + + # add the following 4 lines + image_feat = all_gather_grad(image_feat) + text_feat = all_gather_grad(text_feat) + + logits = torch.matmul(image_feat, text_feat.t()) + logit_scale = temperature.exp().clamp(max=100) + + gt = torch.arange(logits.shape[0], device=logits.device) + loss1 = F.cross_entropy(logit_scale * logits, gt) + loss2 = F.cross_entropy(logit_scale * logits.t(), gt) + return (loss1 + loss2) / 2 # scale it up by the number of GPUs + + +def all_gather_pickle(data, device): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device) + + # obtain Tensor size of each rank + local_size = torch.LongTensor([tensor.numel()]).cuda() + size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = 
max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).cuda()) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).cuda() + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + +def all_gather_arbitary_tensor(tensor): + if get_world_size() > 1: + device = tensor.device + tensor_batch = all_gather_pickle(tensor.cpu(), device) + tensor_batch = [x.to(device) for x in tensor_batch] + tensor_batch[torch.distributed.get_rank()] = tensor + tensor_batch = torch.cat(tensor_batch, dim=0) + else: + tensor_batch = tensor + return tensor_batch + +def ql_contrastive_loss(image_feat, text_feat, temperature=1): + # add the following 4 lines + image_feat = all_gather_arbitary_tensor(image_feat) + text_feat = all_gather_arbitary_tensor(text_feat) + + logits = torch.matmul(image_feat, text_feat.t()) + logit_scale = temperature.exp().clamp(max=100) + + gt = torch.arange(logits.shape[0], device=logits.device) + loss1 = F.cross_entropy(logit_scale * logits, gt) + loss2 = F.cross_entropy(logit_scale * logits.t(), gt) + return (loss1 + loss2) / 2 # scale it up by the number of GPUs + +def vl_similarity(image_feat, text_feat, temperature=1): + # Only support single GPU for now. + logits = torch.matmul(image_feat, text_feat.t()) + logits = temperature.exp().clamp(max=100) * logits + return logits + +def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1): + # add the following 4 lines + image_feat = all_gather_arbitary_tensor(image_feat) + text_feat = all_gather_arbitary_tensor(text_feat) + + text_hash_batch = all_gather_pickle(text_hash, text_feat.device) + text_hash_all = torch.cat(text_hash_batch) + + text_hash_all_unique = torch.unique(text_hash_all).tolist() + gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device) + text_hash_all = text_hash_all.tolist() + text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique]) + + for idx, txt in enumerate(text_hash_all): + gt[idx][text_hash_all_unique.index(txt)] = 1 + + logits = torch.matmul(image_feat, text_feat_unique.t()) + logits = logits*temperature.exp().clamp(max=100) + + loss_img = soft_cross_entropy(logits, gt) + loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True)) + + loss = 0.7 * loss_img + 0.3 * loss_text + return loss + +def image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training): + # add the following 4 lines + image_feat = all_gather_grad(image_feat_inp.contiguous()) + text_feat = all_gather_grad(text_feat_inp.contiguous()) + + image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7) + text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7) + + temperature = lang_enc.logit_scale + logits = torch.matmul(image_feat, text_feat.t()) + logit_scale = temperature.exp().clamp(max=100) + + gt = torch.arange(logits.shape[0], device=logits.device) + loss1 = F.cross_entropy(logit_scale * logits, gt) + loss2 = F.cross_entropy(logit_scale * logits.t(), gt) + + return (loss1 + loss2) / 2 # scale it up by the number of GPUs \ No newline at 
end of file diff --git a/xdecoder/language/misc.py b/xdecoder/language/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..faf172fbb8a90ed49ca0de9a9ca1d875f2f96215 --- /dev/null +++ b/xdecoder/language/misc.py @@ -0,0 +1,64 @@ +import random + +import nltk +nltk.data.path.append('/mnt/data/nltk_data') +import numpy as np + +from utils.constants import IMAGENET_DEFAULT_TEMPLATES + + +def get_tag(tokenized, tags): + if not isinstance(tags, (list, tuple)): + tags = [tags] + ret = [] + for (word, pos) in nltk.pos_tag(tokenized): + for tag in tags: + if pos == tag: + ret.append(word) + return ret + +def get_noun_phrase(tokenized): + # Taken from Su Nam Kim Paper... + grammar = r""" + NBAR: + {<NN.*|JJ>*<NN.*>} # Nouns and Adjectives, terminated with Nouns + + NP: + {<NBAR>} + {<NBAR><IN><NBAR>} # Above, connected with in/of/etc... + """ + chunker = nltk.RegexpParser(grammar) + + chunked = chunker.parse(nltk.pos_tag(tokenized)) + continuous_chunk = [] + current_chunk = [] + + for subtree in chunked: + if isinstance(subtree, nltk.Tree): + current_chunk.append(' '.join([token for token, pos in subtree.leaves()])) + elif current_chunk: + named_entity = ' '.join(current_chunk) + if named_entity not in continuous_chunk: + continuous_chunk.append(named_entity) + current_chunk = [] + else: + continue + + return continuous_chunk + +def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True): + tokenized = nltk.word_tokenize(text) + + if random.random() >= phrase_prob: + nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP']) + else: + nouns = get_noun_phrase(tokenized) + + + prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns] + + if append_text: + prompt_texts += [text] + nouns += [text] + + return prompt_texts, nouns \ No newline at end of file diff --git a/xdecoder/language/registry.py b/xdecoder/language/registry.py new file mode 100755 index 0000000000000000000000000000000000000000..940e4560f7d052aed4915187410266ab5a4cb4d0 --- /dev/null +++ b/xdecoder/language/registry.py @@ -0,0 +1,13 @@ +_model_entrypoints = {} + +def register_model(fn): + module_name_split = fn.__module__.split('.') + model_name = module_name_split[-1] + _model_entrypoints[model_name] = fn + return fn + +def model_entrypoints(model_name): + return _model_entrypoints[model_name] + +def is_model(model_name): + return model_name in _model_entrypoints \ No newline at end of file diff --git a/xdecoder/language/vlpencoder.py b/xdecoder/language/vlpencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..ce6fd4709255e8869749d7401babb373b187d697 --- /dev/null +++ b/xdecoder/language/vlpencoder.py @@ -0,0 +1,168 @@ + +import torch +from torch import nn +from torch.nn import functional as F + +from timm.models.layers import trunc_normal_ + +from .registry import register_model +from ..utils import configurable +from .LangEncoder import build_tokenizer, build_lang_encoder +from utils.misc import prompt_engineering, get_prompt_templates + + +class LanguageEncoder(nn.Module): + + @configurable + def __init__( + self, + tokenizer, + tokenizer_type, + lang_encoder, + lang_projection, + max_token_num, + ): + super().__init__() + self.tokenizer = tokenizer + self.tokenizer_type = tokenizer_type + self.lang_encoder = lang_encoder + self.lang_proj = lang_projection + self.max_token_num = max_token_num + self.logit_scale = nn.Parameter(torch.ones([])) + + @classmethod + def from_config(cls, cfg): + tokenizer = build_tokenizer(cfg['MODEL']['TEXT']) + tokenizer_type = 
cfg['MODEL']['TEXT']['TOKENIZER'] + lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE']) + max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH'] + + dim_lang = cfg['MODEL']['TEXT']['WIDTH'] + dim_projection = cfg['MODEL']['DIM_PROJ'] + lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection)) + trunc_normal_(lang_projection, std=.02) + + return { + "tokenizer": tokenizer, + "tokenizer_type": tokenizer_type, + "lang_encoder": lang_encoder, + "lang_projection": lang_projection, + "max_token_num": max_token_num, + } + + def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True): + if not is_eval: + if prompt: + # randomly sample one template + arbitary_concepts = [ + prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \ + for label in range(len(class_names)) + ] + if add_bgd: + arbitary_concepts.append("A background in coco.") + else: + arbitary_concepts = class_names + + input_ids = [] + attention_masks = [] + for txt in arbitary_concepts: + tokens = self.tokenizer( + txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' + ) + tokens['input_ids'].squeeze_() + tokens['attention_mask'].squeeze_() + + input_ids.append(tokens['input_ids']) + attention_masks.append(tokens['attention_mask']) + + arbitary_tokens = torch.stack(input_ids) + arbitary_attention_masks = torch.stack(attention_masks) + + text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm) + setattr(self, '{}_text_embeddings'.format(name), text_emb) + else: + with torch.no_grad(): + def extract_mean_emb(txts): + tokens = self.tokenizer( + txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' + ) + clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm) + clss_embedding = clss_embedding.mean(dim=0) + clss_embedding /= clss_embedding.norm() + return clss_embedding + + templates = get_prompt_templates() + clss_embeddings = [] + if prompt: + for clss in class_names: + txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates] + clss_embeddings.append(extract_mean_emb(txts)) + else: + clss_embeddings.append(extract_mean_emb(class_names)) + + if add_bgd: + txts = ["A background in coco."] + clss_embeddings.append(extract_mean_emb(txts)) + + text_emb = torch.stack(clss_embeddings, dim=0) + setattr(self, '{}_text_embeddings'.format(name), text_emb) + + def get_text_token_embeddings(self, txts, name='default', token=False, norm=False): + if not token: + tokens = self.tokenizer( + txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt' + ) + tokens = {key: value.cuda() for key, value in tokens.items()} + else: + tokens = txts + token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm) + ret = {"tokens": tokens, + "token_emb": token_emb, + "class_emb": class_emb,} + setattr(self, '{}_token_embeddings'.format(name), ret) + return ret + + def forward_language(self, texts, norm=True): + x = self.lang_encoder(*texts) + x = x['last_hidden_state'] + + if self.tokenizer_type == 'clip': + x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)] + else: + x = x[:, 0] + + x = x @ self.lang_proj + if norm: + x = x / (x.norm(dim=-1, keepdim=True) + 1e-7) 
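+        # with norm=True the projected class embedding is L2-normalized (the 1e-7 guard avoids division by zero for a degenerate embedding), so the dot products taken against the L2-normalized visual features in compute_similarity below are cosine similarities scaled by logit_scale.exp()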
+ return x + + def forward_language_token(self, texts, norm=False): + x = self.lang_encoder(*texts) + token_x = x['last_hidden_state'] + + if self.tokenizer_type == 'clip': + class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)] + else: + class_x = token_x[:, 0] + + class_x = class_x @ self.lang_proj + token_x = token_x @ self.lang_proj + + if norm: + class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7) + token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7) + + return token_x, class_x + + def compute_similarity(self, v_emb, name='default', fake=False): + if fake: + return None + v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7) + t_emb = getattr(self, '{}_text_embeddings'.format(name)) + output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2) + return output + + +@register_model +def get_language_model(cfg, **kwargs): + return LanguageEncoder(cfg) \ No newline at end of file diff --git a/xdecoder/modules/__init__.py b/xdecoder/modules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..6bbbff85221d3e15d34b52f69706896896c47ef3 --- /dev/null +++ b/xdecoder/modules/__init__.py @@ -0,0 +1,3 @@ +from .position_encoding import * +from .attention import * +from .postprocessing import * \ No newline at end of file diff --git a/xdecoder/modules/__pycache__/__init__.cpython-38.pyc b/xdecoder/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06c8296c576ff1d2d4f508d0a8ad06b02ab46056 Binary files /dev/null and b/xdecoder/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/modules/__pycache__/attention.cpython-38.pyc b/xdecoder/modules/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8599bf5769136c530047178a65bb3c41b5bd3652 Binary files /dev/null and b/xdecoder/modules/__pycache__/attention.cpython-38.pyc differ diff --git a/xdecoder/modules/__pycache__/position_encoding.cpython-38.pyc b/xdecoder/modules/__pycache__/position_encoding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaff0366ebd20b0f273948fe43e3ccfeb58e4a16 Binary files /dev/null and b/xdecoder/modules/__pycache__/position_encoding.cpython-38.pyc differ diff --git a/xdecoder/modules/__pycache__/postprocessing.cpython-38.pyc b/xdecoder/modules/__pycache__/postprocessing.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de6962a35528219c9536bfcc0e8df2914cedf92a Binary files /dev/null and b/xdecoder/modules/__pycache__/postprocessing.cpython-38.pyc differ diff --git a/xdecoder/modules/attention.py b/xdecoder/modules/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..a0eadeee1454cfbea58a96595af7c9e552088c6a --- /dev/null +++ b/xdecoder/modules/attention.py @@ -0,0 +1,489 @@ +# Code copy from PyTorch, modified by Xueyan Zou + +import warnings +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.parameter import Parameter +from torch.overrides import has_torch_function, handle_torch_function +from torch.nn.functional import pad, linear, softmax, dropout + + +def multi_head_attention_forward( + query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: 
Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. 
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) + if has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, + tens_ops, + query, + key, + value, + embed_dim_to_check, + num_heads, + in_proj_weight, + in_proj_bias, + bias_k, + bias_v, + add_zero_attn, + dropout_p, + out_proj_weight, + out_proj_bias, + training=training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, + k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, + static_k=static_k, + static_v=static_v, + ) + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if not use_separate_proj_weight: + if (query is key or torch.equal(query, key)) and (key is value or torch.equal(key, value)): + # self-attention + q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif key is value or torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = linear(value, _w, _b) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = 
torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) + k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]) + v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :]) + else: + q = linear(query, q_proj_weight_non_opt, in_proj_bias) + k = linear(key, k_proj_weight_non_opt, in_proj_bias) + v = linear(value, v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
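+    # note: when add_bias_kv is used, bias_k/bias_v are appended as one extra key/value position (the torch.cat above), which is why attn_mask and key_padding_mask are padded by one column in that same branch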
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + # assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + attn_output_weights = softmax(attn_output_weights, dim=-1) + attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +# This class exists solely for Transformer; it has an annotation stating +# that bias is never None, which appeases TorchScript +class _LinearWithBias(nn.Linear): + bias: Tensor # type: ignore + + def __init__(self, in_features: int, out_features: int) -> None: + super().__init__(in_features, out_features, bias=True) # type: ignore + + +class MultiheadAttention(nn.Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See `Attention Is All You Need `_ + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. 
Default: None. + + Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set + to :attr:`embed_dim` such that query, key, and value have the same + number of features. + + Examples:: + + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + if self._qkv_same_embed_dim is False: + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.register_parameter('in_proj_weight', None) + else: + self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) + self.register_parameter('q_proj_weight', None) + self.register_parameter('k_proj_weight', None) + self.register_parameter('v_proj_weight', None) + + if bias: + self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.out_proj = _LinearWithBias(embed_dim, embed_dim) + + if add_bias_kv: + self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) + self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. 
A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shapes for inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the + source sequence length. + + If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence + length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend + the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Shapes for outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not self._qkv_same_embed_dim: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + else: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask) \ No newline at end of file diff --git a/xdecoder/modules/position_encoding.py b/xdecoder/modules/position_encoding.py new file mode 100755 index 0000000000000000000000000000000000000000..09faa117bcd04b9c3f70301347630c4ace39cac2 --- /dev/null +++ b/xdecoder/modules/position_encoding.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +## Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py +""" +Various positional encodings for the transformer. 
+""" +import math + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=x.dtype) + x_embed = not_mask.cumsum(2, dtype=x.dtype) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=x.dtype, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) diff --git a/xdecoder/modules/postprocessing.py b/xdecoder/modules/postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..eef2047589674fda092bebc310bd394a3db57074 --- /dev/null +++ b/xdecoder/modules/postprocessing.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch +from torch.nn import functional as F + +from detectron2.structures import Instances, ROIMasks + + +# perhaps should rename to "resize_instance" +def detector_postprocess( + results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 +): + """ + Resize the output instances. + The input images are often resized when entering an object detector. + As a result, we often need the outputs of the detector in a different + resolution from its inputs. + + This function will resize the raw outputs of an R-CNN detector + to produce outputs according to the desired output resolution. + + Args: + results (Instances): the raw outputs from the detector. + `results.image_size` contains the input image resolution the detector sees. + This object might be modified in-place. + output_height, output_width: the desired output resolution. + + Returns: + Instances: the resized output from the model, based on the output resolution + """ + if isinstance(output_width, torch.Tensor): + # This shape might (but not necessarily) be tensors during tracing. + # Converts integer tensors to float temporaries to ensure true + # division is performed when computing scale_x and scale_y. 
+ output_width_tmp = output_width.float() + output_height_tmp = output_height.float() + new_size = torch.stack([output_height, output_width]) + else: + new_size = (output_height, output_width) + output_width_tmp = output_width + output_height_tmp = output_height + + scale_x, scale_y = ( + output_width_tmp / results.image_size[1], + output_height_tmp / results.image_size[0], + ) + results = Instances(new_size, **results.get_fields()) + + if results.has("pred_boxes"): + output_boxes = results.pred_boxes + elif results.has("proposal_boxes"): + output_boxes = results.proposal_boxes + else: + output_boxes = None + assert output_boxes is not None, "Predictions must contain boxes!" + + output_boxes.scale(scale_x, scale_y) + output_boxes.clip(results.image_size) + + results = results[output_boxes.nonempty()] + + if results.has("pred_masks"): + if isinstance(results.pred_masks, ROIMasks): + roi_masks = results.pred_masks + else: + # pred_masks is a tensor of shape (N, 1, M, M) + roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) + results.pred_masks = roi_masks.to_bitmasks( + results.pred_boxes, output_height, output_width, mask_threshold + ).tensor # TODO return ROIMasks/BitMask object in the future + + if results.has("pred_keypoints"): + results.pred_keypoints[:, :, 0] *= scale_x + results.pred_keypoints[:, :, 1] *= scale_y + + return results + +def bbox_postprocess(result, input_size, img_size, output_height, output_width): + """ + result: [xc,yc,w,h] range [0,1] to [x1,y1,x2,y2] range [0,w], [0,h] + """ + if result is None: + return None + + scale = torch.tensor([input_size[1], input_size[0], input_size[1], input_size[0]])[None,:].to(result.device) + result = result.sigmoid() * scale + x1,y1,x2,y2 = result[:,0] - result[:,2]/2, result[:,1] - result[:,3]/2, result[:,0] + result[:,2]/2, result[:,1] + result[:,3]/2 + h,w = img_size + + x1 = x1.clamp(min=0, max=w) + y1 = y1.clamp(min=0, max=h) + x2 = x2.clamp(min=0, max=w) + y2 = y2.clamp(min=0, max=h) + + box = torch.stack([x1,y1,x2,y2]).permute(1,0) + scale = torch.tensor([output_width/w, output_height/h, output_width/w, output_height/h])[None,:].to(result.device) + box = box*scale + return box + +def sem_seg_postprocess(result, img_size, output_height, output_width): + """ + Return semantic segmentation predictions in the original resolution. + + The input images are often resized when entering semantic segmentor. Moreover, in same + cases, they also padded inside segmentor to be divisible by maximum network stride. + As a result, we often need the predictions of the segmentor in a different + resolution from its inputs. + + Args: + result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W), + where C is the number of classes, and H, W are the height and width of the prediction. + img_size (tuple): image size that segmentor is taking as input. + output_height, output_width: the desired output resolution. + + Returns: + semantic segmentation prediction (Tensor): A tensor of the shape + (C, output_height, output_width) that contains per-pixel soft predictions. 
+ """ + result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1) + result = F.interpolate( + result, size=(output_height, output_width), mode="bilinear", align_corners=False + )[0] + return result diff --git a/xdecoder/utils/__init__.py b/xdecoder/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..4ca95fb0709a0af80e45d7fc35aa3eb31bac9f13 --- /dev/null +++ b/xdecoder/utils/__init__.py @@ -0,0 +1,4 @@ +from .config import * +from .misc import * +from .box_ops import * +from .it_contrastive import * \ No newline at end of file diff --git a/xdecoder/utils/__pycache__/__init__.cpython-38.pyc b/xdecoder/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..485f23b2eab026e23ba0be7f13d480ef5988b244 Binary files /dev/null and b/xdecoder/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/xdecoder/utils/__pycache__/box_ops.cpython-38.pyc b/xdecoder/utils/__pycache__/box_ops.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5145924c6995d3162fe5da54eb38cdbaa7ed181c Binary files /dev/null and b/xdecoder/utils/__pycache__/box_ops.cpython-38.pyc differ diff --git a/xdecoder/utils/__pycache__/config.cpython-38.pyc b/xdecoder/utils/__pycache__/config.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06b75298a9860396aa9f7e836617bd97ad5960b7 Binary files /dev/null and b/xdecoder/utils/__pycache__/config.cpython-38.pyc differ diff --git a/xdecoder/utils/__pycache__/it_contrastive.cpython-38.pyc b/xdecoder/utils/__pycache__/it_contrastive.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f900a2c5a5ee8e00110cac1b9f883212957e6df Binary files /dev/null and b/xdecoder/utils/__pycache__/it_contrastive.cpython-38.pyc differ diff --git a/xdecoder/utils/__pycache__/misc.cpython-38.pyc b/xdecoder/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2dbc16f0f1752f05c7b386aef63c4109db434eb Binary files /dev/null and b/xdecoder/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/xdecoder/utils/box_ops.py b/xdecoder/utils/box_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..42f93d5d48e25657e9f46ccef1a17064b8c192f7 --- /dev/null +++ b/xdecoder/utils/box_ops.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities for bounding box manipulation and GIoU. 
+""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + +def box_xywh_to_xyxy(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [x0, y0, (x0 + x1), (y0 + y1)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) \ No newline at end of file diff --git a/xdecoder/utils/config.py b/xdecoder/utils/config.py new file mode 100755 index 0000000000000000000000000000000000000000..766bb386498f0f034485a19027d5b30b0b6d20ff --- /dev/null +++ b/xdecoder/utils/config.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import functools +import inspect + +def configurable(init_func=None, *, from_config=None): + """ + Decorate a function or a class's __init__ method so that it can be called + with a :class:`CfgNode` object using a :func:`from_config` function that translates + :class:`CfgNode` to arguments. 
+ + Examples: + :: + # Usage 1: Decorator on __init__: + class A: + @configurable + def __init__(self, a, b=2, c=3): + pass + + @classmethod + def from_config(cls, cfg): # 'cfg' must be the first argument + # Returns kwargs to be passed to __init__ + return {"a": cfg.A, "b": cfg.B} + + a1 = A(a=1, b=2) # regular construction + a2 = A(cfg) # construct with a cfg + a3 = A(cfg, b=3, c=4) # construct with extra overwrite + + # Usage 2: Decorator on any function. Needs an extra from_config argument: + @configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B}) + def a_func(a, b=2, c=3): + pass + + a1 = a_func(a=1, b=2) # regular call + a2 = a_func(cfg) # call with a cfg + a3 = a_func(cfg, b=3, c=4) # call with extra overwrite + + Args: + init_func (callable): a class's ``__init__`` method in usage 1. The + class must have a ``from_config`` classmethod which takes `cfg` as + the first argument. + from_config (callable): the from_config function in usage 2. It must take `cfg` + as its first argument. + """ + + if init_func is not None: + assert ( + inspect.isfunction(init_func) + and from_config is None + and init_func.__name__ == "__init__" + ), "Incorrect use of @configurable. Check API documentation for examples." + + @functools.wraps(init_func) + def wrapped(self, *args, **kwargs): + try: + from_config_func = type(self).from_config + except AttributeError as e: + raise AttributeError( + "Class with @configurable must have a 'from_config' classmethod." + ) from e + if not inspect.ismethod(from_config_func): + raise TypeError("Class with @configurable must have a 'from_config' classmethod.") + + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config(from_config_func, *args, **kwargs) + init_func(self, **explicit_args) + else: + init_func(self, *args, **kwargs) + + return wrapped + + else: + if from_config is None: + return configurable # @configurable() is made equivalent to @configurable + assert inspect.isfunction( + from_config + ), "from_config argument of configurable must be a function!" + + def wrapper(orig_func): + @functools.wraps(orig_func) + def wrapped(*args, **kwargs): + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config(from_config, *args, **kwargs) + return orig_func(**explicit_args) + else: + return orig_func(*args, **kwargs) + + wrapped.from_config = from_config + return wrapped + + return wrapper + +def _called_with_cfg(*args, **kwargs): + """ + Returns: + bool: whether the arguments contain CfgNode and should be considered + forwarded to from_config. + """ + from omegaconf import DictConfig + + if len(args) and isinstance(args[0], (dict)): + return True + if isinstance(kwargs.pop("cfg", None), (dict)): + return True + # `from_config`'s first argument is forced to be "cfg". + # So the above check covers all cases. + return False + +def _get_args_from_config(from_config_func, *args, **kwargs): + """ + Use `from_config` to obtain explicit arguments. 
+ + Returns: + dict: arguments to be used for cls.__init__ + """ + signature = inspect.signature(from_config_func) + if list(signature.parameters.keys())[0] != "cfg": + if inspect.isfunction(from_config_func): + name = from_config_func.__name__ + else: + name = f"{from_config_func.__self__}.from_config" + raise TypeError(f"{name} must take 'cfg' as the first argument!") + support_var_arg = any( + param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD] + for param in signature.parameters.values() + ) + if support_var_arg: # forward all arguments to from_config, if from_config accepts them + ret = from_config_func(*args, **kwargs) + else: + # forward supported arguments to from_config + supported_arg_names = set(signature.parameters.keys()) + extra_kwargs = {} + for name in list(kwargs.keys()): + if name not in supported_arg_names: + extra_kwargs[name] = kwargs.pop(name) + ret = from_config_func(*args, **kwargs) + # forward the other arguments to __init__ + ret.update(extra_kwargs) + return ret \ No newline at end of file diff --git a/xdecoder/utils/it_contrastive.py b/xdecoder/utils/it_contrastive.py new file mode 100755 index 0000000000000000000000000000000000000000..b30fd2dae6221c2c244e5b48109e282a6e2e1533 --- /dev/null +++ b/xdecoder/utils/it_contrastive.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +def is_dist_initialized(): + return torch.distributed.is_initialized() + +def get_world_size(): + if is_dist_initialized(): + return torch.distributed.get_world_size() + return 1 + +def all_gather_grad(x): + if get_world_size() > 1: + all_x = [torch.zeros_like(x) for _ in range(get_world_size())] + torch.distributed.all_gather(all_x, x) + all_x[torch.distributed.get_rank()] = x + x = torch.cat(all_x, dim=0) + return x + +@torch.no_grad() +def all_gather_nograd(tensor): + # from albef + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + if get_world_size() > 1: + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + tensor = torch.cat(tensors_gather, dim=0) + return tensor + +def image_text_contrastive_loss(image_feat, text_feat, temperature, image_id=None, text_id=None): + # add the following 4 lines + image_feat = all_gather_grad(image_feat) + text_feat = all_gather_grad(text_feat) + + logits = torch.matmul(image_feat, text_feat.t()) + logits /= temperature + + if image_id is None and text_id is None: + gt = torch.arange(logits.shape[0], device=logits.device) + loss1 = F.cross_entropy(logits, gt) + loss2 = F.cross_entropy(logits.t(), gt) + else: + image_id = all_gather_grad(image_id) + text_id = all_gather_grad(text_id) + + gt_image = image_id.reshape((-1, 1)) == image_id.reshape((1, -1)) + gt_text = text_id.reshape((-1, 1)) == text_id.reshape((1, -1)) + gt = torch.logical_or(gt_image, gt_text) + + loss1 = -torch.sum(gt * F.log_softmax(logits, dim=1)) / gt.sum() + loss2 = -torch.sum(gt.t() * F.log_softmax(logits.t(), dim=1)) / gt.sum() + + return (loss1 + loss2) / 2 * get_world_size() # scale it up by the number of GPUs diff --git a/xdecoder/utils/misc.py b/xdecoder/utils/misc.py new file mode 100755 index 0000000000000000000000000000000000000000..e7bfa08060344fedcb1d5017b932a3c16fc5bc86 --- /dev/null +++ b/xdecoder/utils/misc.py @@ -0,0 +1,157 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py +# Modified by Xueyan Zou +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +from typing import List, Optional + +import torch +import torch.distributed as dist +import torchvision +from torch import Tensor + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], : img.shape[2]] = False + elif tensor_list[0].ndim == 2: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(txt.shape) for txt in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, l = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, l), dtype=torch.bool, device=device) + for txt, pad_txt, m in zip(tensor_list, tensor, mask): + pad_txt[: txt.shape[0], : txt.shape[1]] = txt + m[: txt.shape[1]] = False + else: + raise ValueError("not supported") + return NestedTensor(tensor, mask) + +def _collate_and_pad_divisibility(tensor_list: list, div=32): + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max( + torch.tensor([img.shape[i] for img in tensor_list]).to(torch.float32) + ).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + c,h,w = max_size + pad_h = (div - h % div) if h % div != 0 else 0 + pad_w = (div - w % div) if w % div != 0 else 0 + max_size = (c,h+pad_h,w+pad_w) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: 
img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + return padded_imgs + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max( + torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) + ).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True \ No newline at end of file diff --git a/xdecoder_focalt_last.pt b/xdecoder_focalt_last.pt new file mode 100644 index 0000000000000000000000000000000000000000..9cbf4b0274c0eb16d1921a687ab84618e70c3630 --- /dev/null +++ b/xdecoder_focalt_last.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9ddc9672a1fb8c0e463b4bc0c0e788739d08899b89c5cb901e581e3bbda6fb6d +size 658330805 diff --git a/xdecoder_focalt_last_novg.pt b/xdecoder_focalt_last_novg.pt new file mode 100644 index 0000000000000000000000000000000000000000..81f3b4720da031198269851fc5288a3599416819 --- /dev/null +++ b/xdecoder_focalt_last_novg.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9d18e951784e9d6d84897cd1d87849b0c69333dafe8e5b358b284f4282990d0 +size 658330805
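
The contrastive helpers added in this patch (xdecoder/utils/it_contrastive.py and the loss functions shown earlier in the diff) all reduce, on a single GPU, to the same symmetric InfoNCE pattern: L2-normalized image and text features, a temperature-scaled similarity matrix, and cross-entropy in both directions against the diagonal of matching pairs. The snippet below is a minimal standalone sketch of that single-GPU path; it mirrors image_text_contrastive_loss with image_id/text_id left as None rather than importing the repo, and the function name, batch size, feature width, and fixed temperature of 0.07 are illustrative assumptions, not values taken from the code above.

import torch
import torch.nn.functional as F

def symmetric_clip_loss(image_feat, text_feat, temperature=0.07):
    # image_feat, text_feat: (B, C), already L2-normalized as in the repo code
    logits = image_feat @ text_feat.t() / temperature          # (B, B) similarity matrix
    gt = torch.arange(logits.shape[0], device=logits.device)   # matching pairs sit on the diagonal
    loss_i2t = F.cross_entropy(logits, gt)                     # image -> text direction
    loss_t2i = F.cross_entropy(logits.t(), gt)                 # text -> image direction
    return (loss_i2t + loss_t2i) / 2

if __name__ == "__main__":
    torch.manual_seed(0)
    B, C = 8, 512                                              # illustrative sizes
    image_feat = F.normalize(torch.randn(B, C), dim=-1)
    text_feat = F.normalize(torch.randn(B, C), dim=-1)
    print(symmetric_clip_loss(image_feat, text_feat).item())   # close to log(B) for unrelated random features

The distributed variants in the diff differ only in gathering features across ranks before building the logits (all_gather_grad or all_gather_arbitary_tensor) and in using a learnable logit scale, temperature.exp().clamp(max=100), in place of the fixed divisor used here.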