Awiny commited on
Commit
c3a1897
Β·
1 Parent(s): 353fa54

first version submission

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. app.py +59 -4
  2. main_gradio.py +84 -0
  3. models/__pycache__/blip2_model.cpython-38.pyc +0 -0
  4. models/__pycache__/blip2_model.cpython-39.pyc +0 -0
  5. models/__pycache__/controlnet_model.cpython-38.pyc +0 -0
  6. models/__pycache__/gpt_model.cpython-38.pyc +0 -0
  7. models/__pycache__/grit_model.cpython-38.pyc +0 -0
  8. models/__pycache__/image_text_transformation.cpython-38.pyc +0 -0
  9. models/__pycache__/image_text_transformation.cpython-39.pyc +0 -0
  10. models/__pycache__/region_semantic.cpython-38.pyc +0 -0
  11. models/blip2_model.py +38 -0
  12. models/controlnet_model.py +51 -0
  13. models/gpt_model.py +40 -0
  14. models/grit_model.py +26 -0
  15. models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc +0 -0
  16. models/grit_src/configs/Base.yaml +77 -0
  17. models/grit_src/configs/GRiT_B_DenseCap.yaml +20 -0
  18. models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml +23 -0
  19. models/grit_src/configs/GRiT_B_ObjectDet.yaml +20 -0
  20. models/grit_src/configs/GRiT_H_ObjectDet.yaml +21 -0
  21. models/grit_src/configs/GRiT_L_ObjectDet.yaml +20 -0
  22. models/grit_src/grit/__init__.py +7 -0
  23. models/grit_src/grit/__pycache__/__init__.cpython-38.pyc +0 -0
  24. models/grit_src/grit/__pycache__/config.cpython-38.pyc +0 -0
  25. models/grit_src/grit/__pycache__/predictor.cpython-38.pyc +0 -0
  26. models/grit_src/grit/config.py +50 -0
  27. models/grit_src/grit/custom_solver.py +88 -0
  28. models/grit_src/grit/data/__pycache__/custom_build_augmentation.cpython-38.pyc +0 -0
  29. models/grit_src/grit/data/__pycache__/custom_dataset_mapper.cpython-38.pyc +0 -0
  30. models/grit_src/grit/data/custom_build_augmentation.py +44 -0
  31. models/grit_src/grit/data/custom_dataset_dataloader.py +250 -0
  32. models/grit_src/grit/data/custom_dataset_mapper.py +149 -0
  33. models/grit_src/grit/data/datasets/__pycache__/grit_coco.cpython-38.pyc +0 -0
  34. models/grit_src/grit/data/datasets/__pycache__/object365.cpython-38.pyc +0 -0
  35. models/grit_src/grit/data/datasets/__pycache__/vg.cpython-38.pyc +0 -0
  36. models/grit_src/grit/data/datasets/grit_coco.py +112 -0
  37. models/grit_src/grit/data/datasets/object365.py +111 -0
  38. models/grit_src/grit/data/datasets/vg.py +98 -0
  39. models/grit_src/grit/data/transforms/__pycache__/custom_augmentation_impl.cpython-38.pyc +0 -0
  40. models/grit_src/grit/data/transforms/__pycache__/custom_transform.cpython-38.pyc +0 -0
  41. models/grit_src/grit/data/transforms/custom_augmentation_impl.py +52 -0
  42. models/grit_src/grit/data/transforms/custom_transform.py +115 -0
  43. models/grit_src/grit/evaluation/eval.py +156 -0
  44. models/grit_src/grit/modeling/__pycache__/soft_nms.cpython-38.pyc +0 -0
  45. models/grit_src/grit/modeling/backbone/__pycache__/utils.cpython-38.pyc +0 -0
  46. models/grit_src/grit/modeling/backbone/__pycache__/vit.cpython-38.pyc +0 -0
  47. models/grit_src/grit/modeling/backbone/utils.py +186 -0
  48. models/grit_src/grit/modeling/backbone/vit.py +538 -0
  49. models/grit_src/grit/modeling/meta_arch/__pycache__/grit.cpython-38.pyc +0 -0
  50. models/grit_src/grit/modeling/meta_arch/grit.py +66 -0
app.py CHANGED
@@ -1,7 +1,62 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ import base64
6
+ from io import BytesIO
7
+ from models.image_text_transformation import ImageTextTransformation
8
 
9
+ def pil_image_to_base64(image):
10
+ buffered = BytesIO()
11
+ image.save(buffered, format="JPEG")
12
+ img_str = base64.b64encode(buffered.getvalue()).decode()
13
+ return img_str
14
 
15
+ def add_logo():
16
+ with open("examples/logo.png", "rb") as f:
17
+ logo_base64 = base64.b64encode(f.read()).decode()
18
+ return logo_base64
19
+
20
+ def process_image(image_src, processor):
21
+ gen_text = processor.image_to_text(image_src)
22
+ gen_image = processor.text_to_image(gen_text)
23
+ gen_image_str = pil_image_to_base64(gen_image)
24
+ # Combine the outputs into a single HTML output
25
+ custom_output = f'''
26
+ <h2>Image->Text->Image:</h2>
27
+ <div style="display: flex; flex-wrap: wrap;">
28
+ <div style="flex: 1;">
29
+ <h3>Image2Text</h3>
30
+ <p>{gen_text}</p>
31
+ </div>
32
+ <div style="flex: 1;">
33
+ <h3>Text2Image</h3>
34
+ <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
35
+ </div>
36
+ </div>
37
+ '''
38
+
39
+ return custom_output
40
+
41
+ processor = ImageTextTransformation()
42
+
43
+ # Create Gradio input and output components
44
+ image_input = gr.inputs.Image(type='filepath', label="Input Image")
45
+
46
+ logo_base64 = add_logo()
47
+ # Create the title with the logo
48
+ title_with_logo = f'<img src="data:image/jpeg;base64,{logo_base64}" width="400" style="vertical-align: middle;"> Understanding Image with Text'
49
+
50
+ # Create Gradio interface
51
+ interface = gr.Interface(
52
+ fn=lambda image: process_image(image, processor), # Pass the processor object using a lambda function
53
+ inputs=image_input,
54
+ outputs=gr.outputs.HTML(),
55
+ title=title_with_logo,
56
+ description="""
57
+ This code support image to text transformation. Then the generated text can do retrieval, question answering et al to conduct zero-shot.
58
+ """
59
+ )
60
+
61
+ # Launch the interface
62
+ interface.launch()
main_gradio.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ import base64
6
+ from io import BytesIO
7
+ from models.image_text_transformation import ImageTextTransformation
8
+
9
+ def pil_image_to_base64(image):
10
+ buffered = BytesIO()
11
+ image.save(buffered, format="JPEG")
12
+ img_str = base64.b64encode(buffered.getvalue()).decode()
13
+ return img_str
14
+
15
+ def add_logo():
16
+ with open("examples/logo.png", "rb") as f:
17
+ logo_base64 = base64.b64encode(f.read()).decode()
18
+ return logo_base64
19
+
20
+ def process_image(image_src, processor):
21
+ gen_text = processor.image_to_text(image_src)
22
+ gen_image = processor.text_to_image(gen_text)
23
+ gen_image_str = pil_image_to_base64(gen_image)
24
+ # Combine the outputs into a single HTML output
25
+ custom_output = f'''
26
+ <h2>Image->Text->Image:</h2>
27
+ <div style="display: flex; flex-wrap: wrap;">
28
+ <div style="flex: 1;">
29
+ <h3>Image2Text</h3>
30
+ <p>{gen_text}</p>
31
+ </div>
32
+ <div style="flex: 1;">
33
+ <h3>Text2Image</h3>
34
+ <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
35
+ </div>
36
+ </div>
37
+ <h2>Using Source Image to do Retrieval on COCO:</h2>
38
+ <div style="display: flex; flex-wrap: wrap;">
39
+ <div style="flex: 1;">
40
+ <h3>Retrieval Top-3 Text</h3>
41
+ <p>{gen_text}</p>
42
+ </div>
43
+ <div style="flex: 1;">
44
+ <h3>Retrieval Top-3 Image</h3>
45
+ <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
46
+ </div>
47
+ </div>
48
+ <h2>Using Generated texts to do Retrieval on COCO:</h2>
49
+ <div style="display: flex; flex-wrap: wrap;">
50
+ <div style="flex: 1;">
51
+ <h3>Retrieval Top-3 Text</h3>
52
+ <p>{gen_text}</p>
53
+ </div>
54
+ <div style="flex: 1;">
55
+ <h3>Retrieval Top-3 Image</h3>
56
+ <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
57
+ </div>
58
+ </div>
59
+ '''
60
+
61
+ return custom_output
62
+
63
+ processor = ImageTextTransformation()
64
+
65
+ # Create Gradio input and output components
66
+ image_input = gr.inputs.Image(type='filepath', label="Input Image")
67
+
68
+ logo_base64 = add_logo()
69
+ # Create the title with the logo
70
+ title_with_logo = f'<img src="data:image/jpeg;base64,{logo_base64}" width="400" style="vertical-align: middle;"> Understanding Image with Text'
71
+
72
+ # Create Gradio interface
73
+ interface = gr.Interface(
74
+ fn=lambda image: process_image(image, processor), # Pass the processor object using a lambda function
75
+ inputs=image_input,
76
+ outputs=gr.outputs.HTML(),
77
+ title=title_with_logo,
78
+ description="""
79
+ This code support image to text transformation. Then the generated text can do retrieval, question answering et al to conduct zero-shot.
80
+ """
81
+ )
82
+
83
+ # Launch the interface
84
+ interface.launch()
models/__pycache__/blip2_model.cpython-38.pyc ADDED
Binary file (1.88 kB). View file
 
models/__pycache__/blip2_model.cpython-39.pyc ADDED
Binary file (1.88 kB). View file
 
models/__pycache__/controlnet_model.cpython-38.pyc ADDED
Binary file (1.88 kB). View file
 
models/__pycache__/gpt_model.cpython-38.pyc ADDED
Binary file (2.28 kB). View file
 
models/__pycache__/grit_model.cpython-38.pyc ADDED
Binary file (1.38 kB). View file
 
models/__pycache__/image_text_transformation.cpython-38.pyc ADDED
Binary file (2.55 kB). View file
 
models/__pycache__/image_text_transformation.cpython-39.pyc ADDED
Binary file (2.55 kB). View file
 
models/__pycache__/region_semantic.cpython-38.pyc ADDED
Binary file (2.2 kB). View file
 
models/blip2_model.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import requests
3
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
4
+ import torch
5
+
6
+
7
+ class ImageCaptioning:
8
+ def __init__(self) -> None:
9
+ self.device = None
10
+ # self.processor, self.model = None, None
11
+ self.processor, self.model = self.initialize_model()
12
+
13
+ def initialize_model(self):
14
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ self.device = "cpu" # for low gpu memory devices
16
+ if self.device == 'cpu':
17
+ self.data_type = torch.float32
18
+ else:
19
+ self.data_type = torch.float16
20
+ processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
21
+ model = Blip2ForConditionalGeneration.from_pretrained(
22
+ "Salesforce/blip2-opt-2.7b", torch_dtype=self.data_type
23
+ )
24
+ model.to(self.device)
25
+ return processor, model
26
+
27
+ def image_caption(self, image_src):
28
+ image = Image.open(image_src)
29
+ inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
30
+ generated_ids = self.model.generate(**inputs)
31
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
32
+ print('*'*100 + '\nStep1, BLIP2 caption:')
33
+ print(generated_text)
34
+ print('\n' + '*'*100)
35
+ return generated_text
36
+
37
+ def image_caption_debug(self, image_src):
38
+ return "A dish with salmon, broccoli, and something yellow."
models/controlnet_model.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from diffusers import (
6
+ StableDiffusionControlNetPipeline,
7
+ ControlNetModel,
8
+ UniPCMultistepScheduler,
9
+ )
10
+
11
+
12
+ class TextToImage:
13
+ def __init__(self):
14
+ # self.model = None
15
+ self.model = self.initialize_model()
16
+
17
+ def initialize_model(self):
18
+ controlnet = ControlNetModel.from_pretrained(
19
+ "fusing/stable-diffusion-v1-5-controlnet-canny",
20
+ torch_dtype=torch.float16,
21
+ )
22
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
23
+ "runwayml/stable-diffusion-v1-5",
24
+ controlnet=controlnet,
25
+ safety_checker=None,
26
+ torch_dtype=torch.float16,
27
+ )
28
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(
29
+ pipeline.scheduler.config
30
+ )
31
+ pipeline.enable_model_cpu_offload()
32
+ return pipeline
33
+
34
+ @staticmethod
35
+ def preprocess_image(image):
36
+ image = np.array(image)
37
+ low_threshold = 100
38
+ high_threshold = 200
39
+ image = cv2.Canny(image, low_threshold, high_threshold)
40
+ image = np.stack([image, image, image], axis=2)
41
+ image = Image.fromarray(image)
42
+ return image
43
+
44
+ def text_to_image(self, text, image):
45
+ image = self.preprocess_image(image)
46
+ generated_image = self.model(text, image, num_inference_steps=20).images[0]
47
+ return generated_image
48
+
49
+ def text_to_image_debug(self, text, image):
50
+ print("text_to_image_debug")
51
+ return image
models/gpt_model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+
3
+ class ImageToText:
4
+ def __init__(self, api_key):
5
+ self.template = self.initialize_template()
6
+ openai.api_key = api_key
7
+
8
+ def initialize_template(self):
9
+ prompt_prefix_1 = """Generate only an informative and nature paragraph based on the given information(a,b,c,d):\n"""
10
+ prompt_prefix_2 = """\n a. Image Resolution: """
11
+ prompt_prefix_3 = """\n b. Image Caption: """
12
+ prompt_prefix_4 = """\n c. Dense Caption: """
13
+ prompt_prefix_5 = """\n d. Region Semantic: """
14
+ prompt_suffix = """\n There are some rules:
15
+ Show object, color and position.
16
+ Use nouns rather than coordinates to show position information of each object.
17
+ No more than 7 sentences.
18
+ Only use one paragraph.
19
+ Do not appear number.
20
+ """
21
+ template = f"{prompt_prefix_1}{prompt_prefix_2}{{width}}X{{height}}{prompt_prefix_3}{{caption}}{prompt_prefix_4}{{dense_caption}}{prompt_prefix_5}{{region_semantic}}{prompt_suffix}"
22
+ return template
23
+
24
+ def paragraph_summary_with_gpt(self, caption, dense_caption, region_semantic, width, height):
25
+ question = self.template.format(width=width, height=height, caption=caption, dense_caption=dense_caption, region_semantic=region_semantic)
26
+ print('*'*100)
27
+ print("question:", question)
28
+ completion = openai.ChatCompletion.create(
29
+ model="gpt-3.5-turbo",
30
+ messages = [
31
+ {"role": "user", "content" : question}]
32
+ )
33
+ print("chatgpt response:", completion['choices'][0]['message']['content'])
34
+ print('*'*100)
35
+ return completion['choices'][0]['message']['content']
36
+
37
+ def paragraph_summary_with_gpt_debug(self, caption, dense_caption, width, height):
38
+ question = self.template.format(width=width, height=height, caption=caption, dense_caption=dense_caption)
39
+ print("paragraph_summary_with_gpt_debug:")
40
+ return question
models/grit_model.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from models.grit_src.image_dense_captions import image_caption_api
3
+
4
+ class DenseCaptioning():
5
+ def __init__(self) -> None:
6
+ self.model = None
7
+
8
+
9
+ def initialize_model(self):
10
+ pass
11
+
12
+ def image_dense_caption_debug(self, image_src):
13
+ dense_caption = """
14
+ 1. the broccoli is green, [0, 0, 333, 325];
15
+ 2. a piece of broccoli, [0, 147, 143, 324];
16
+ 3. silver fork on plate, [4, 547, 252, 612];
17
+ """
18
+ return dense_caption
19
+
20
+ def image_dense_caption(self, image_src):
21
+ dense_caption = image_caption_api(image_src)
22
+ print("Step2, Dense Caption:\n")
23
+ print(dense_caption)
24
+ print('\n'+'*'*100)
25
+ return dense_caption
26
+
models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc ADDED
Binary file (2.54 kB). View file
 
models/grit_src/configs/Base.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ META_ARCHITECTURE: "GRiT"
3
+ MASK_ON: True
4
+ PROPOSAL_GENERATOR:
5
+ NAME: "CenterNet"
6
+ FPN:
7
+ IN_FEATURES: ["layer3", "layer4", "layer5"]
8
+ PIXEL_MEAN: [123.675, 116.280, 103.530]
9
+ PIXEL_STD: [58.395, 57.12, 57.375]
10
+ ROI_HEADS:
11
+ NAME: GRiTROIHeadsAndTextDecoder
12
+ IN_FEATURES: ["p3", "p4", "p5"]
13
+ IOU_THRESHOLDS: [0.6]
14
+ NUM_CLASSES: 1
15
+ SCORE_THRESH_TEST: 0.02
16
+ NMS_THRESH_TEST: 0.5
17
+ OBJECT_FEAT_POOLER_RES: 14
18
+ ROI_BOX_CASCADE_HEAD:
19
+ IOUS: [0.6, 0.7, 0.8]
20
+ ROI_BOX_HEAD:
21
+ NAME: "FastRCNNConvFCHead"
22
+ NUM_FC: 2
23
+ POOLER_RESOLUTION: 7
24
+ CLS_AGNOSTIC_BBOX_REG: True
25
+ MULT_PROPOSAL_SCORE: True
26
+ ROI_MASK_HEAD:
27
+ NAME: "MaskRCNNConvUpsampleHead"
28
+ NUM_CONV: 4
29
+ POOLER_RESOLUTION: 14
30
+ CLS_AGNOSTIC_MASK: True
31
+ CENTERNET:
32
+ NUM_CLASSES: 1
33
+ REG_WEIGHT: 1.
34
+ NOT_NORM_REG: True
35
+ ONLY_PROPOSAL: True
36
+ WITH_AGN_HM: True
37
+ INFERENCE_TH: 0.0001
38
+ PRE_NMS_TOPK_TRAIN: 4000
39
+ POST_NMS_TOPK_TRAIN: 2000
40
+ PRE_NMS_TOPK_TEST: 1000
41
+ POST_NMS_TOPK_TEST: 256
42
+ NMS_TH_TRAIN: 0.9
43
+ NMS_TH_TEST: 0.9
44
+ POS_WEIGHT: 0.5
45
+ NEG_WEIGHT: 0.5
46
+ IGNORE_HIGH_FP: 0.85
47
+ DATASETS:
48
+ TRAIN: ("coco_2017_train",)
49
+ TEST: ("coco_2017_val",)
50
+ DATALOADER:
51
+ SAMPLER_TRAIN: "MultiDatasetSampler"
52
+ DATASET_RATIO: [1]
53
+ DATASET_INPUT_SIZE: [1024]
54
+ DATASET_INPUT_SCALE: [[0.1, 2.0]]
55
+ FILTER_EMPTY_ANNOTATIONS: False
56
+ NUM_WORKERS: 8
57
+ TEST:
58
+ DETECTIONS_PER_IMAGE: 256
59
+ SOLVER:
60
+ LR_SCHEDULER_NAME: "WarmupCosineLR"
61
+ CHECKPOINT_PERIOD: 10000
62
+ WARMUP_ITERS: 1000
63
+ WARMUP_FACTOR: 0.001
64
+ USE_CUSTOM_SOLVER: True
65
+ OPTIMIZER: "ADAMW"
66
+ MAX_ITER: 180000
67
+ IMS_PER_BATCH: 64
68
+ BASE_LR: 0.00008
69
+ VIT_LAYER_DECAY: True
70
+ CLIP_GRADIENTS:
71
+ ENABLED: True
72
+ INPUT:
73
+ FORMAT: RGB
74
+ CUSTOM_AUG: EfficientDetResizeCrop
75
+ TRAIN_SIZE: 640
76
+ USE_ACT_CHECKPOINT: True
77
+ VERSION: 2
models/grit_src/configs/GRiT_B_DenseCap.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base.yaml"
2
+ MODEL:
3
+ TRAIN_TASK: ["DenseCap"]
4
+ TEST_TASK: "DenseCap"
5
+ MASK_ON: False
6
+ ROI_HEADS:
7
+ SOFT_NMS_ENABLED: False
8
+ BEAM_SIZE: 1
9
+ WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
10
+ BACKBONE:
11
+ NAME: build_vit_fpn_backbone
12
+ VIT_LAYERS: 12
13
+ SOLVER:
14
+ VIT_LAYER_DECAY_RATE: 0.7
15
+ DATASETS:
16
+ TRAIN: ("vg_train",)
17
+ TEST: ("vg_test",)
18
+ DATALOADER:
19
+ DATASET_BS: 2
20
+ OUTPUT_DIR: "./output/GRiT_B_DenseCap"
models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base.yaml"
2
+ MODEL:
3
+ TRAIN_TASK: ["ObjectDet", "DenseCap"]
4
+ TEST_TASK: "DenseCap" # DenseCap or ObjectDet: Choose one for testing
5
+ MASK_ON: True
6
+ ROI_HEADS:
7
+ SOFT_NMS_ENABLED: False
8
+ BEAM_SIZE: 1
9
+ WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
10
+ BACKBONE:
11
+ NAME: build_vit_fpn_backbone
12
+ VIT_LAYERS: 12
13
+ SOLVER:
14
+ VIT_LAYER_DECAY_RATE: 0.7
15
+ DATASETS:
16
+ TRAIN: ("GRiT_coco2017_train", "vg_train")
17
+ TEST: ("coco_2017_test-dev",)
18
+ DATALOADER:
19
+ DATASET_RATIO: [1, 1]
20
+ DATASET_BS: 2
21
+ DATASET_INPUT_SIZE: [1024, 1024]
22
+ DATASET_INPUT_SCALE: [[0.1, 2.0], [0.1, 2.0]]
23
+ OUTPUT_DIR: "./output/GRiT_B_DenseCap_ObjectDet"
models/grit_src/configs/GRiT_B_ObjectDet.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base.yaml"
2
+ MODEL:
3
+ TRAIN_TASK: ["ObjectDet"]
4
+ TEST_TASK: "ObjectDet"
5
+ MASK_ON: True
6
+ ROI_HEADS:
7
+ SOFT_NMS_ENABLED: True
8
+ BEAM_SIZE: 3
9
+ WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_base.pth"
10
+ BACKBONE:
11
+ NAME: build_vit_fpn_backbone
12
+ VIT_LAYERS: 12
13
+ SOLVER:
14
+ VIT_LAYER_DECAY_RATE: 0.7
15
+ DATASETS:
16
+ TRAIN: ("GRiT_coco2017_train",)
17
+ TEST: ("coco_2017_val",)
18
+ DATALOADER:
19
+ DATASET_BS: 2
20
+ OUTPUT_DIR: "./output/GRiT_B_ObjectDet"
models/grit_src/configs/GRiT_H_ObjectDet.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base.yaml"
2
+ MODEL:
3
+ TRAIN_TASK: ["ObjectDet"]
4
+ TEST_TASK: "ObjectDet"
5
+ MASK_ON: True
6
+ ROI_HEADS:
7
+ SOFT_NMS_ENABLED: True
8
+ BEAM_SIZE: 3
9
+ WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_huge_p14to16.pth"
10
+ BACKBONE:
11
+ NAME: build_vit_fpn_backbone_huge
12
+ VIT_LAYERS: 32
13
+ SOLVER:
14
+ MAX_ITER: 135000
15
+ VIT_LAYER_DECAY_RATE: 0.9
16
+ DATASETS:
17
+ TRAIN: ("GRiT_coco2017_train",)
18
+ TEST: ("coco_2017_val",)
19
+ DATALOADER:
20
+ DATASET_BS: 1
21
+ OUTPUT_DIR: "./output/GRiT_H_ObjectDet"
models/grit_src/configs/GRiT_L_ObjectDet.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _BASE_: "Base.yaml"
2
+ MODEL:
3
+ TRAIN_TASK: ["ObjectDet"]
4
+ TEST_TASK: "ObjectDet"
5
+ MASK_ON: True
6
+ ROI_HEADS:
7
+ SOFT_NMS_ENABLED: True
8
+ BEAM_SIZE: 3
9
+ WEIGHTS: "detectron2://ImageNetPretrained/MAE/mae_pretrain_vit_large.pth"
10
+ BACKBONE:
11
+ NAME: build_vit_fpn_backbone_large
12
+ VIT_LAYERS: 24
13
+ SOLVER:
14
+ VIT_LAYER_DECAY_RATE: 0.8
15
+ DATASETS:
16
+ TRAIN: ("GRiT_coco2017_train",)
17
+ TEST: ("coco_2017_val",)
18
+ DATALOADER:
19
+ DATASET_BS: 1
20
+ OUTPUT_DIR: "./output/GRiT_L_ObjectDet"
models/grit_src/grit/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .modeling.meta_arch import grit
2
+ from .modeling.roi_heads import grit_roi_heads
3
+ from .modeling.backbone import vit
4
+
5
+ from .data.datasets import object365
6
+ from .data.datasets import vg
7
+ from .data.datasets import grit_coco
models/grit_src/grit/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (405 Bytes). View file
 
models/grit_src/grit/__pycache__/config.cpython-38.pyc ADDED
Binary file (1.4 kB). View file
 
models/grit_src/grit/__pycache__/predictor.cpython-38.pyc ADDED
Binary file (2.65 kB). View file
 
models/grit_src/grit/config.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from detectron2.config import CfgNode as CN
2
+
3
+
4
+ def add_grit_config(cfg):
5
+ _C = cfg
6
+
7
+ _C.MODEL.BEAM_SIZE = 1
8
+ _C.MODEL.TRAIN_TASK = ["ObjectDet", "DenseCap"]
9
+ _C.MODEL.TEST_TASK = "DenseCap" # This can be varied if the model is jointly trained on multiple tasks
10
+
11
+ _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0 # >= 0: not use
12
+ _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False
13
+
14
+ _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0
15
+ _C.MODEL.ROI_HEADS.OBJECT_FEAT_POOLER_RES = 14
16
+ _C.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
17
+
18
+ # Backbones
19
+ _C.MODEL.VIT_LAYERS = 12
20
+
21
+ # Text Decoder
22
+ _C.TEXT_DECODER = CN()
23
+ _C.TEXT_DECODER.VOCAB_SIZE = 30522
24
+ _C.TEXT_DECODER.HIDDEN_SIZE = 768
25
+ _C.TEXT_DECODER.NUM_LAYERS = 6
26
+ _C.TEXT_DECODER.ATTENTION_HEADS = 12
27
+ _C.TEXT_DECODER.FEEDFORWARD_SIZE = 768 * 4
28
+
29
+ # Multi-dataset dataloader
30
+ _C.DATALOADER.DATASET_RATIO = [1, 1] # sample ratio
31
+ _C.DATALOADER.DATASET_BS = 1
32
+ _C.DATALOADER.DATASET_INPUT_SIZE = [1024, 1024]
33
+ _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.1, 2.0)]
34
+ _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (640, 800)]
35
+ _C.DATALOADER.DATASET_MAX_SIZES = [1333, 1333]
36
+
37
+ _C.SOLVER.USE_CUSTOM_SOLVER = True
38
+ _C.SOLVER.OPTIMIZER = 'ADAMW'
39
+ _C.SOLVER.VIT_LAYER_DECAY = True
40
+ _C.SOLVER.VIT_LAYER_DECAY_RATE = 0.7
41
+
42
+ _C.INPUT.CUSTOM_AUG = 'EfficientDetResizeCrop'
43
+ _C.INPUT.TRAIN_SIZE = 1024
44
+ _C.INPUT.TEST_SIZE = 1024
45
+ _C.INPUT.SCALE_RANGE = (0.1, 2.)
46
+ # 'default' for fixed short / long edge
47
+ _C.INPUT.TEST_INPUT_TYPE = 'default'
48
+
49
+ _C.FIND_UNUSED_PARAM = True
50
+ _C.USE_ACT_CHECKPOINT = True
models/grit_src/grit/custom_solver.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ # Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/custom_solver.py
3
+ import itertools
4
+ from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
5
+ import torch
6
+
7
+ from detectron2.config import CfgNode
8
+
9
+ from detectron2.solver.build import maybe_add_gradient_clipping
10
+
11
+
12
+ def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
13
+ params: List[Dict[str, Any]] = []
14
+ memo: Set[torch.nn.parameter.Parameter] = set()
15
+ optimizer_type = cfg.SOLVER.OPTIMIZER
16
+
17
+ for key, value in model.named_parameters(recurse=True):
18
+ if not value.requires_grad:
19
+ continue
20
+ # Avoid duplicating parameters
21
+ if value in memo:
22
+ continue
23
+ memo.add(value)
24
+ lr = cfg.SOLVER.BASE_LR
25
+ weight_decay = cfg.SOLVER.WEIGHT_DECAY
26
+
27
+ if cfg.SOLVER.VIT_LAYER_DECAY:
28
+ lr = lr * get_vit_lr_decay_rate(key, cfg.SOLVER.VIT_LAYER_DECAY_RATE, cfg.MODEL.VIT_LAYERS)
29
+
30
+ param = {"params": [value], "lr": lr}
31
+ if optimizer_type != 'ADAMW':
32
+ param['weight_decay'] = weight_decay
33
+ params += [param]
34
+
35
+ def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
36
+ # detectron2 doesn't have full model gradient clipping now
37
+ clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
38
+ enable = (
39
+ cfg.SOLVER.CLIP_GRADIENTS.ENABLED
40
+ and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
41
+ and clip_norm_val > 0.0
42
+ )
43
+
44
+ class FullModelGradientClippingOptimizer(optim):
45
+ def step(self, closure=None):
46
+ all_params = itertools.chain(*[x["params"] for x in self.param_groups])
47
+ torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
48
+ super().step(closure=closure)
49
+
50
+ return FullModelGradientClippingOptimizer if enable else optim
51
+
52
+
53
+ if optimizer_type == 'SGD':
54
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
55
+ params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
56
+ nesterov=cfg.SOLVER.NESTEROV
57
+ )
58
+ elif optimizer_type == 'ADAMW':
59
+ optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
60
+ params, cfg.SOLVER.BASE_LR,
61
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY
62
+ )
63
+ else:
64
+ raise NotImplementedError(f"no optimizer type {optimizer_type}")
65
+ if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
66
+ optimizer = maybe_add_gradient_clipping(cfg, optimizer)
67
+ return optimizer
68
+
69
+
70
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
71
+ """
72
+ Calculate lr decay rate for different ViT blocks.
73
+ Args:
74
+ name (string): parameter name.
75
+ lr_decay_rate (float): base lr decay rate.
76
+ num_layers (int): number of ViT blocks.
77
+
78
+ Returns:
79
+ lr decay rate for the given parameter.
80
+ """
81
+ layer_id = num_layers + 1
82
+ if name.startswith("backbone"):
83
+ if ".pos_embed" in name or ".patch_embed" in name:
84
+ layer_id = 0
85
+ elif ".blocks." in name and ".residual." not in name:
86
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
87
+
88
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
models/grit_src/grit/data/__pycache__/custom_build_augmentation.cpython-38.pyc ADDED
Binary file (1.21 kB). View file
 
models/grit_src/grit/data/__pycache__/custom_dataset_mapper.cpython-38.pyc ADDED
Binary file (5.68 kB). View file
 
models/grit_src/grit/data/custom_build_augmentation.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ from detectron2.data import transforms as T
3
+ from .transforms.custom_augmentation_impl import EfficientDetResizeCrop
4
+
5
+
6
+ def build_custom_augmentation(cfg, is_train, scale=None, size=None, \
7
+ min_size=None, max_size=None):
8
+ """
9
+ Create a list of default :class:`Augmentation` from config.
10
+ Now it includes resizing and flipping.
11
+
12
+ Returns:
13
+ list[Augmentation]
14
+ """
15
+ if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge':
16
+ if is_train:
17
+ min_size = cfg.INPUT.MIN_SIZE_TRAIN if min_size is None else min_size
18
+ max_size = cfg.INPUT.MAX_SIZE_TRAIN if max_size is None else max_size
19
+ sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
20
+ else:
21
+ min_size = cfg.INPUT.MIN_SIZE_TEST
22
+ max_size = cfg.INPUT.MAX_SIZE_TEST
23
+ sample_style = "choice"
24
+ augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
25
+ elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
26
+ if is_train:
27
+ scale = cfg.INPUT.SCALE_RANGE if scale is None else scale
28
+ size = cfg.INPUT.TRAIN_SIZE if size is None else size
29
+ else:
30
+ scale = (1, 1)
31
+ size = cfg.INPUT.TEST_SIZE
32
+ augmentation = [EfficientDetResizeCrop(size, scale)]
33
+ else:
34
+ assert 0, cfg.INPUT.CUSTOM_AUG
35
+
36
+ if is_train:
37
+ augmentation.append(T.RandomFlip())
38
+ return augmentation
39
+
40
+
41
+ build_custom_transform_gen = build_custom_augmentation
42
+ """
43
+ Alias for backward-compatibility.
44
+ """
models/grit_src/grit/data/custom_dataset_dataloader.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/data/custom_dataset_dataloader.py
3
+ import operator
4
+ import torch
5
+ import torch.utils.data
6
+ from detectron2.utils.comm import get_world_size
7
+
8
+ from detectron2.config import configurable
9
+ from torch.utils.data.sampler import BatchSampler, Sampler
10
+ from detectron2.data.common import DatasetFromList, MapDataset
11
+ from detectron2.data.dataset_mapper import DatasetMapper
12
+ from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
13
+ from detectron2.data.samplers import TrainingSampler
14
+ from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
15
+ from detectron2.data.build import filter_images_with_only_crowd_annotations
16
+ from detectron2.data.build import filter_images_with_few_keypoints
17
+ from detectron2.data.build import check_metadata_consistency
18
+ from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
19
+ from detectron2.utils import comm
20
+ import itertools
21
+ from typing import Optional
22
+
23
+
24
+ def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
25
+ sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
26
+ if 'MultiDataset' in sampler_name:
27
+ dataset_dicts = get_detection_dataset_dicts_with_source(
28
+ cfg.DATASETS.TRAIN,
29
+ filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
30
+ min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
31
+ if cfg.MODEL.KEYPOINT_ON else 0,
32
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
33
+ )
34
+ else:
35
+ dataset_dicts = get_detection_dataset_dicts(
36
+ cfg.DATASETS.TRAIN,
37
+ filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
38
+ min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
39
+ if cfg.MODEL.KEYPOINT_ON else 0,
40
+ proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
41
+ )
42
+
43
+ if mapper is None:
44
+ mapper = DatasetMapper(cfg, True)
45
+
46
+ if sampler is not None:
47
+ pass
48
+ elif sampler_name == "TrainingSampler":
49
+ sampler = TrainingSampler(len(dataset))
50
+ elif sampler_name == "MultiDatasetSampler":
51
+ sampler = MultiDatasetSampler(
52
+ dataset_dicts,
53
+ dataset_ratio=cfg.DATALOADER.DATASET_RATIO,
54
+ )
55
+ else:
56
+ raise ValueError("Unknown training sampler: {}".format(sampler_name))
57
+
58
+ return {
59
+ "dataset": dataset_dicts,
60
+ "sampler": sampler,
61
+ "mapper": mapper,
62
+ "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
63
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
64
+ 'dataset_bs': cfg.DATALOADER.DATASET_BS,
65
+ 'num_datasets': len(cfg.DATASETS.TRAIN)
66
+ }
67
+
68
+
69
+ @configurable(from_config=_custom_train_loader_from_config)
70
+ def build_custom_train_loader(
71
+ dataset, *, mapper, sampler,
72
+ total_batch_size=16,
73
+ num_workers=0,
74
+ num_datasets=1,
75
+ dataset_bs=1
76
+ ):
77
+
78
+ if isinstance(dataset, list):
79
+ dataset = DatasetFromList(dataset, copy=False)
80
+ if mapper is not None:
81
+ dataset = MapDataset(dataset, mapper)
82
+ if sampler is None:
83
+ sampler = TrainingSampler(len(dataset))
84
+ assert isinstance(sampler, torch.utils.data.sampler.Sampler)
85
+
86
+ return build_dataset_batch_data_loader(
87
+ dataset_bs,
88
+ dataset,
89
+ sampler,
90
+ total_batch_size,
91
+ num_datasets=num_datasets,
92
+ num_workers=num_workers,
93
+ )
94
+
95
+
96
+ def build_dataset_batch_data_loader(
97
+ dataset_bs, dataset, sampler, total_batch_size, num_datasets, num_workers=0
98
+ ):
99
+
100
+ world_size = get_world_size()
101
+ assert (
102
+ total_batch_size > 0 and total_batch_size % world_size == 0
103
+ ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
104
+ total_batch_size, world_size
105
+ )
106
+
107
+ data_loader = torch.utils.data.DataLoader(
108
+ dataset,
109
+ sampler=sampler,
110
+ num_workers=num_workers,
111
+ batch_sampler=None,
112
+ collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
113
+ worker_init_fn=worker_init_reset_seed,
114
+ )
115
+
116
+ if num_datasets > 1:
117
+ return MultiDatasets(data_loader, dataset_bs, num_datasets)
118
+ else:
119
+ return SingleDataset(data_loader, dataset_bs)
120
+
121
+
122
+ def get_detection_dataset_dicts_with_source(
123
+ dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
124
+ ):
125
+ assert len(dataset_names)
126
+ dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
127
+ for dataset_name, dicts in zip(dataset_names, dataset_dicts):
128
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
129
+
130
+ for source_id, (dataset_name, dicts) in \
131
+ enumerate(zip(dataset_names, dataset_dicts)):
132
+ assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
133
+ for d in dicts:
134
+ d['dataset_source'] = source_id
135
+
136
+ if "annotations" in dicts[0]:
137
+ try:
138
+ class_names = MetadataCatalog.get(dataset_name).thing_classes
139
+ check_metadata_consistency("thing_classes", dataset_name)
140
+ print_instances_class_histogram(dicts, class_names)
141
+ except AttributeError: # class names are not available for this dataset
142
+ pass
143
+
144
+ assert proposal_files is None
145
+
146
+ dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
147
+
148
+ has_instances = "annotations" in dataset_dicts[0]
149
+ if filter_empty and has_instances:
150
+ dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
151
+ if min_keypoints > 0 and has_instances:
152
+ dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
153
+
154
+ return dataset_dicts
155
+
156
+
157
+ class MultiDatasetSampler(Sampler):
158
+ def __init__(
159
+ self,
160
+ dataset_dicts,
161
+ dataset_ratio,
162
+ seed: Optional[int] = None,
163
+ ):
164
+ sizes = [0 for _ in range(len(dataset_ratio))]
165
+ for d in dataset_dicts:
166
+ sizes[d['dataset_source']] += 1
167
+ print('dataset sizes', sizes)
168
+ self.sizes = sizes
169
+ assert len(dataset_ratio) == len(sizes), \
170
+ 'length of dataset ratio {} should be equal to number if dataset {}'.format(
171
+ len(dataset_ratio), len(sizes)
172
+ )
173
+ if seed is None:
174
+ seed = comm.shared_random_seed()
175
+ self._seed = int(seed)
176
+ self._rank = comm.get_rank()
177
+ self._world_size = comm.get_world_size()
178
+
179
+ self.dataset_ids = torch.tensor(
180
+ [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)
181
+ self.dataset_ratio = dataset_ratio
182
+
183
+ dataset_weight = [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) \
184
+ for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]
185
+ dataset_weight = torch.cat(dataset_weight)
186
+
187
+ self.weights = dataset_weight
188
+ self.sample_epoch_size = len(self.weights)
189
+
190
+ def __iter__(self):
191
+ start = self._rank
192
+ yield from itertools.islice(
193
+ self._infinite_indices(), start, None, self._world_size)
194
+
195
+ def _infinite_indices(self):
196
+ g = torch.Generator()
197
+ g.manual_seed(self._seed)
198
+ while True:
199
+ if len(self.dataset_ratio) > 1:
200
+ # multiple datasets
201
+ ids = torch.multinomial(
202
+ self.weights, self.sample_epoch_size, generator=g,
203
+ replacement=True)
204
+ nums = [(self.dataset_ids[ids] == i).sum().int().item() \
205
+ for i in range(len(self.sizes))]
206
+ yield from ids
207
+ else:
208
+ # single dataset
209
+ yield from torch.randperm(self.sizes[0], generator=g).tolist()
210
+
211
+
212
+ class SingleDataset(torch.utils.data.IterableDataset):
213
+ def __init__(self, dataset, batch_sizes):
214
+ self.dataset = dataset
215
+ self.batch_sizes = batch_sizes
216
+ self._buckets = [[] for _ in range(2)]
217
+
218
+ def __iter__(self):
219
+ for d in self.dataset:
220
+ w, h = d["width"], d["height"]
221
+ aspect_ratio_bucket_id = 0 if w > h else 1
222
+ bucket_id = aspect_ratio_bucket_id
223
+ bucket = self._buckets[bucket_id]
224
+ bucket.append(d)
225
+ if len(bucket) == self.batch_sizes:
226
+ yield bucket[:]
227
+ del bucket[:]
228
+
229
+
230
+ class MultiDatasets(torch.utils.data.IterableDataset):
231
+ def __init__(self, dataset, batch_sizes, num_datasets):
232
+ self.dataset = dataset
233
+ self.batch_sizes = batch_sizes
234
+ self._buckets = [[] for _ in range(2 * num_datasets)]
235
+ self.iter_idx = 0
236
+ self.num_datasets = num_datasets
237
+
238
+ def __iter__(self):
239
+ for d in self.dataset:
240
+ w, h = d["width"], d["height"]
241
+ aspect_ratio_bucket_id = 0 if w > h else 1
242
+ bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
243
+ bucket = self._buckets[bucket_id]
244
+ if len(bucket) < self.batch_sizes:
245
+ bucket.append(d)
246
+ selected_dataset = self.iter_idx % self.num_datasets
247
+ if len(bucket) == self.batch_sizes and selected_dataset == d['dataset_source']:
248
+ self.iter_idx += 1
249
+ yield bucket[:]
250
+ del bucket[:]
models/grit_src/grit/data/custom_dataset_mapper.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ # Modified by Jialian Wu from https://github.com/facebookresearch/Detic/blob/main/detic/data/custom_dataset_mapper.py
3
+ import copy
4
+ import numpy as np
5
+ import torch
6
+
7
+ from detectron2.config import configurable
8
+
9
+ from detectron2.data import detection_utils as utils
10
+ from detectron2.data import transforms as T
11
+ from detectron2.data.dataset_mapper import DatasetMapper
12
+ from .custom_build_augmentation import build_custom_augmentation
13
+ from itertools import compress
14
+ import logging
15
+
16
+ __all__ = ["CustomDatasetMapper", "ObjDescription"]
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class CustomDatasetMapper(DatasetMapper):
21
+ @configurable
22
+ def __init__(self, is_train: bool,
23
+ dataset_augs=[],
24
+ **kwargs):
25
+ if is_train:
26
+ self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs]
27
+ super().__init__(is_train, **kwargs)
28
+
29
+ @classmethod
30
+ def from_config(cls, cfg, is_train: bool = True):
31
+ ret = super().from_config(cfg, is_train)
32
+ if is_train:
33
+ if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
34
+ dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE
35
+ dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE
36
+ ret['dataset_augs'] = [
37
+ build_custom_augmentation(cfg, True, scale, size) \
38
+ for scale, size in zip(dataset_scales, dataset_sizes)]
39
+ else:
40
+ assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge'
41
+ min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES
42
+ max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES
43
+ ret['dataset_augs'] = [
44
+ build_custom_augmentation(
45
+ cfg, True, min_size=mi, max_size=ma) \
46
+ for mi, ma in zip(min_sizes, max_sizes)]
47
+ else:
48
+ ret['dataset_augs'] = []
49
+
50
+ return ret
51
+
52
+ def __call__(self, dataset_dict):
53
+ dataset_dict_out = self.prepare_data(dataset_dict)
54
+
55
+ # When augmented image is too small, do re-augmentation
56
+ retry = 0
57
+ while (dataset_dict_out["image"].shape[1] < 32 or dataset_dict_out["image"].shape[2] < 32):
58
+ retry += 1
59
+ if retry == 100:
60
+ logger.info('Retry 100 times for augmentation. Make sure the image size is not too small.')
61
+ logger.info('Find image information below')
62
+ logger.info(dataset_dict)
63
+ dataset_dict_out = self.prepare_data(dataset_dict)
64
+
65
+ return dataset_dict_out
66
+
67
+ def prepare_data(self, dataset_dict_in):
68
+ dataset_dict = copy.deepcopy(dataset_dict_in)
69
+ if 'file_name' in dataset_dict:
70
+ ori_image = utils.read_image(
71
+ dataset_dict["file_name"], format=self.image_format)
72
+ else:
73
+ ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]]
74
+ ori_image = utils._apply_exif_orientation(ori_image)
75
+ ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format)
76
+ utils.check_image_size(dataset_dict, ori_image)
77
+
78
+ aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=None)
79
+ if self.is_train:
80
+ transforms = \
81
+ self.dataset_augs[dataset_dict['dataset_source']](aug_input)
82
+ else:
83
+ transforms = self.augmentations(aug_input)
84
+ image, sem_seg_gt = aug_input.image, aug_input.sem_seg
85
+
86
+ image_shape = image.shape[:2]
87
+ dataset_dict["image"] = torch.as_tensor(
88
+ np.ascontiguousarray(image.transpose(2, 0, 1)))
89
+
90
+ if not self.is_train:
91
+ # USER: Modify this if you want to keep them for some reason.
92
+ dataset_dict.pop("annotations", None)
93
+ return dataset_dict
94
+
95
+ if "annotations" in dataset_dict:
96
+ if len(dataset_dict["annotations"]) > 0:
97
+ object_descriptions = [an['object_description'] for an in dataset_dict["annotations"]]
98
+ else:
99
+ object_descriptions = []
100
+ # USER: Modify this if you want to keep them for some reason.
101
+ for anno in dataset_dict["annotations"]:
102
+ if not self.use_instance_mask:
103
+ anno.pop("segmentation", None)
104
+ if not self.use_keypoint:
105
+ anno.pop("keypoints", None)
106
+
107
+ all_annos = [
108
+ (utils.transform_instance_annotations(
109
+ obj, transforms, image_shape,
110
+ keypoint_hflip_indices=self.keypoint_hflip_indices,
111
+ ), obj.get("iscrowd", 0))
112
+ for obj in dataset_dict.pop("annotations")
113
+ ]
114
+ annos = [ann[0] for ann in all_annos if ann[1] == 0]
115
+ instances = utils.annotations_to_instances(
116
+ annos, image_shape, mask_format=self.instance_mask_format
117
+ )
118
+
119
+ instances.gt_object_descriptions = ObjDescription(object_descriptions)
120
+
121
+ del all_annos
122
+ if self.recompute_boxes:
123
+ instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
124
+ dataset_dict["instances"] = utils.filter_empty_instances(instances)
125
+
126
+ return dataset_dict
127
+
128
+
129
+ class ObjDescription:
130
+ def __init__(self, object_descriptions):
131
+ self.data = object_descriptions
132
+
133
+ def __getitem__(self, item):
134
+ assert type(item) == torch.Tensor
135
+ assert item.dim() == 1
136
+ if len(item) > 0:
137
+ assert item.dtype == torch.int64 or item.dtype == torch.bool
138
+ if item.dtype == torch.int64:
139
+ return ObjDescription([self.data[x.item()] for x in item])
140
+ elif item.dtype == torch.bool:
141
+ return ObjDescription(list(compress(self.data, item)))
142
+
143
+ return ObjDescription(list(compress(self.data, item)))
144
+
145
+ def __len__(self):
146
+ return len(self.data)
147
+
148
+ def __repr__(self):
149
+ return "ObjDescription({})".format(self.data)
models/grit_src/grit/data/datasets/__pycache__/grit_coco.cpython-38.pyc ADDED
Binary file (3.94 kB). View file
 
models/grit_src/grit/data/datasets/__pycache__/object365.cpython-38.pyc ADDED
Binary file (3.7 kB). View file
 
models/grit_src/grit/data/datasets/__pycache__/vg.cpython-38.pyc ADDED
Binary file (3.28 kB). View file
 
models/grit_src/grit/data/datasets/grit_coco.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from fvcore.common.timer import Timer
4
+ from detectron2.structures import BoxMode
5
+ from fvcore.common.file_io import PathManager
6
+ from detectron2.data import DatasetCatalog, MetadataCatalog
7
+ from lvis import LVIS
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ __all__ = ["load_GRiTcoco_json", "register_GRiTcoco_instances"]
12
+
13
+
14
+ def register_GRiTcoco_instances(name, metadata, json_file, image_root):
15
+ """
16
+ """
17
+ DatasetCatalog.register(name, lambda: load_GRiTcoco_json(
18
+ json_file, image_root, name))
19
+ MetadataCatalog.get(name).set(
20
+ json_file=json_file, image_root=image_root,
21
+ evaluator_type="coco", **metadata
22
+ )
23
+
24
+
25
+ def get_GRiTcoco_meta():
26
+ categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
27
+ categories = sorted(categories, key=lambda x: x["id"])
28
+ thing_classes = [k["name"] for k in categories]
29
+ meta = {"thing_classes": thing_classes}
30
+ return meta
31
+
32
+
33
+ def load_GRiTcoco_json(json_file, image_root, dataset_name=None):
34
+ '''
35
+ Load COCO class name text for object description for GRiT
36
+ '''
37
+
38
+ json_file = PathManager.get_local_path(json_file)
39
+
40
+ timer = Timer()
41
+ lvis_api = LVIS(json_file)
42
+ if timer.seconds() > 1:
43
+ logger.info("Loading {} takes {:.2f} seconds.".format(
44
+ json_file, timer.seconds()))
45
+
46
+ class_names = {}
47
+ sort_cat = sorted(lvis_api.dataset['categories'], key=lambda x: x['id'])
48
+ for x in sort_cat:
49
+ class_names[x['id']] = x['name']
50
+
51
+ img_ids = sorted(lvis_api.imgs.keys())
52
+ imgs = lvis_api.load_imgs(img_ids)
53
+ anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
54
+
55
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
56
+ assert len(set(ann_ids)) == len(ann_ids), \
57
+ "Annotation ids in '{}' are not unique".format(json_file)
58
+
59
+ imgs_anns = list(zip(imgs, anns))
60
+ logger.info("Loaded {} images in the LVIS v1 format from {}".format(
61
+ len(imgs_anns), json_file))
62
+
63
+ dataset_dicts = []
64
+
65
+ for (img_dict, anno_dict_list) in imgs_anns:
66
+ record = {}
67
+ if "file_name" in img_dict:
68
+ file_name = img_dict["file_name"]
69
+ record["file_name"] = os.path.join(image_root, file_name)
70
+
71
+ record["height"] = int(img_dict["height"])
72
+ record["width"] = int(img_dict["width"])
73
+ image_id = record["image_id"] = img_dict["id"]
74
+
75
+ objs = []
76
+ for anno in anno_dict_list:
77
+ assert anno["image_id"] == image_id
78
+ if anno.get('iscrowd', 0) > 0:
79
+ continue
80
+ obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
81
+ obj["category_id"] = 0
82
+ obj["object_description"] = class_names[anno['category_id']]
83
+ if 'segmentation' in anno:
84
+ segm = anno["segmentation"]
85
+ valid_segm = [poly for poly in segm \
86
+ if len(poly) % 2 == 0 and len(poly) >= 6]
87
+ if not len(segm) == len(valid_segm):
88
+ print('Annotation contains an invalid polygon with < 3 points')
89
+ assert len(segm) > 0
90
+ obj["segmentation"] = segm
91
+ objs.append(obj)
92
+ record["annotations"] = objs
93
+ if len(record["annotations"]) == 0:
94
+ continue
95
+ record["task"] = "ObjectDet"
96
+ dataset_dicts.append(record)
97
+
98
+ return dataset_dicts
99
+
100
+
101
+ _CUSTOM_SPLITS_LVIS = {
102
+ "GRiT_coco2017_train": ("coco/train2017/", "coco/annotations/instances_train2017.json"),
103
+ }
104
+
105
+
106
+ for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
107
+ register_GRiTcoco_instances(
108
+ key,
109
+ get_GRiTcoco_meta(),
110
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
111
+ os.path.join("datasets", image_root),
112
+ )
models/grit_src/grit/data/datasets/object365.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from fvcore.common.timer import Timer
4
+ from detectron2.structures import BoxMode
5
+ from fvcore.common.file_io import PathManager
6
+ from detectron2.data import DatasetCatalog, MetadataCatalog
7
+ from lvis import LVIS
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ __all__ = ["load_o365_json", "register_o365_instances"]
12
+
13
+
14
+ def register_o365_instances(name, metadata, json_file, image_root):
15
+ DatasetCatalog.register(name, lambda: load_o365_json(
16
+ json_file, image_root, name))
17
+ MetadataCatalog.get(name).set(
18
+ json_file=json_file, image_root=image_root,
19
+ evaluator_type="lvis", **metadata
20
+ )
21
+
22
+
23
+ def get_o365_meta():
24
+ categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
25
+ o365_categories = sorted(categories, key=lambda x: x["id"])
26
+ thing_classes = [k["name"] for k in o365_categories]
27
+ meta = {"thing_classes": thing_classes}
28
+ return meta
29
+
30
+
31
+ def load_o365_json(json_file, image_root, dataset_name=None):
32
+ '''
33
+ Load Object365 class name text for object description for GRiT
34
+ '''
35
+
36
+ json_file = PathManager.get_local_path(json_file)
37
+
38
+ timer = Timer()
39
+ lvis_api = LVIS(json_file)
40
+ if timer.seconds() > 1:
41
+ logger.info("Loading {} takes {:.2f} seconds.".format(
42
+ json_file, timer.seconds()))
43
+
44
+ class_names = {}
45
+ sort_cat = sorted(lvis_api.dataset['categories'], key=lambda x: x['id'])
46
+ for x in sort_cat:
47
+ if '/' in x['name']:
48
+ text = ''
49
+ for xx in x['name'].split('/'):
50
+ text += xx
51
+ text += ' '
52
+ text = text[:-1]
53
+ else:
54
+ text = x['name']
55
+ class_names[x['id']] = text
56
+
57
+ img_ids = sorted(lvis_api.imgs.keys())
58
+ imgs = lvis_api.load_imgs(img_ids)
59
+ anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
60
+
61
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
62
+ assert len(set(ann_ids)) == len(ann_ids), \
63
+ "Annotation ids in '{}' are not unique".format(json_file)
64
+
65
+ imgs_anns = list(zip(imgs, anns))
66
+ logger.info("Loaded {} images in the LVIS v1 format from {}".format(
67
+ len(imgs_anns), json_file))
68
+
69
+ dataset_dicts = []
70
+
71
+ for (img_dict, anno_dict_list) in imgs_anns:
72
+ record = {}
73
+ if "file_name" in img_dict:
74
+ file_name = img_dict["file_name"]
75
+ record["file_name"] = os.path.join(image_root, file_name)
76
+
77
+ record["height"] = int(img_dict["height"])
78
+ record["width"] = int(img_dict["width"])
79
+ image_id = record["image_id"] = img_dict["id"]
80
+
81
+ objs = []
82
+ for anno in anno_dict_list:
83
+ assert anno["image_id"] == image_id
84
+ if anno.get('iscrowd', 0) > 0:
85
+ continue
86
+ obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
87
+ obj["category_id"] = 0
88
+ obj["object_description"] = class_names[anno['category_id']]
89
+
90
+ objs.append(obj)
91
+ record["annotations"] = objs
92
+ if len(record["annotations"]) == 0:
93
+ continue
94
+ record["task"] = "ObjectDet"
95
+ dataset_dicts.append(record)
96
+
97
+ return dataset_dicts
98
+
99
+
100
+ _CUSTOM_SPLITS_LVIS = {
101
+ "object365_train": ("object365/images/train/", "object365/annotations/train_v1.json"),
102
+ }
103
+
104
+
105
+ for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
106
+ register_o365_instances(
107
+ key,
108
+ get_o365_meta(),
109
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
110
+ os.path.join("datasets", image_root),
111
+ )
models/grit_src/grit/data/datasets/vg.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from fvcore.common.timer import Timer
4
+ from detectron2.structures import BoxMode
5
+ from fvcore.common.file_io import PathManager
6
+ from detectron2.data import DatasetCatalog, MetadataCatalog
7
+ from lvis import LVIS
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ __all__ = ["load_vg_json", "register_vg_instances"]
12
+
13
+
14
+ def register_vg_instances(name, metadata, json_file, image_root):
15
+ """
16
+ """
17
+ DatasetCatalog.register(name, lambda: load_vg_json(
18
+ json_file, image_root, name))
19
+ MetadataCatalog.get(name).set(
20
+ json_file=json_file, image_root=image_root,
21
+ evaluator_type="vg", **metadata
22
+ )
23
+
24
+
25
+ def get_vg_meta():
26
+ categories = [{'supercategory': 'object', 'id': 1, 'name': 'object'}]
27
+ vg_categories = sorted(categories, key=lambda x: x["id"])
28
+ thing_classes = [k["name"] for k in vg_categories]
29
+ meta = {"thing_classes": thing_classes}
30
+ return meta
31
+
32
+
33
+ def load_vg_json(json_file, image_root, dataset_name=None):
34
+
35
+ json_file = PathManager.get_local_path(json_file)
36
+
37
+ timer = Timer()
38
+ lvis_api = LVIS(json_file)
39
+ if timer.seconds() > 1:
40
+ logger.info("Loading {} takes {:.2f} seconds.".format(
41
+ json_file, timer.seconds()))
42
+
43
+ img_ids = sorted(lvis_api.imgs.keys())
44
+ imgs = lvis_api.load_imgs(img_ids)
45
+ anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
46
+
47
+ ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
48
+ assert len(set(ann_ids)) == len(ann_ids), \
49
+ "Annotation ids in '{}' are not unique".format(json_file)
50
+
51
+ imgs_anns = list(zip(imgs, anns))
52
+ logger.info("Loaded {} images in the LVIS v1 format from {}".format(
53
+ len(imgs_anns), json_file))
54
+
55
+ dataset_dicts = []
56
+
57
+ for (img_dict, anno_dict_list) in imgs_anns:
58
+ record = {}
59
+ if "file_name" in img_dict:
60
+ file_name = img_dict["file_name"]
61
+ record["file_name"] = os.path.join(image_root, file_name)
62
+
63
+ record["height"] = int(img_dict["height"])
64
+ record["width"] = int(img_dict["width"])
65
+ image_id = record["image_id"] = img_dict["id"]
66
+
67
+ objs = []
68
+ for anno in anno_dict_list:
69
+ assert anno["image_id"] == image_id
70
+ if anno.get('iscrowd', 0) > 0:
71
+ continue
72
+ obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
73
+ obj["category_id"] = 0
74
+ obj["object_description"] = anno["caption"]
75
+
76
+ objs.append(obj)
77
+ record["annotations"] = objs
78
+ if len(record["annotations"]) == 0:
79
+ continue
80
+ record["task"] = "DenseCap"
81
+ dataset_dicts.append(record)
82
+
83
+ return dataset_dicts
84
+
85
+
86
+ _CUSTOM_SPLITS_LVIS = {
87
+ "vg_train": ("vg/images", "vg/annotations/train.json"),
88
+ "vg_test": ("vg/images", "vg/annotations/test.json"),
89
+ }
90
+
91
+
92
+ for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
93
+ register_vg_instances(
94
+ key,
95
+ get_vg_meta(),
96
+ os.path.join("datasets", json_file) if "://" not in json_file else json_file,
97
+ os.path.join("datasets", image_root),
98
+ )
models/grit_src/grit/data/transforms/__pycache__/custom_augmentation_impl.cpython-38.pyc ADDED
Binary file (1.73 kB). View file
 
models/grit_src/grit/data/transforms/__pycache__/custom_transform.cpython-38.pyc ADDED
Binary file (3.89 kB). View file
 
models/grit_src/grit/data/transforms/custom_augmentation_impl.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3
+ # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
4
+ # Modified by Xingyi Zhou
5
+ # The original code is under Apache-2.0 License
6
+ import numpy as np
7
+ from PIL import Image
8
+
9
+ from detectron2.data.transforms.augmentation import Augmentation
10
+ from .custom_transform import EfficientDetResizeCropTransform
11
+
12
+ __all__ = [
13
+ "EfficientDetResizeCrop",
14
+ ]
15
+
16
+
17
+ class EfficientDetResizeCrop(Augmentation):
18
+ """
19
+ Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge.
20
+ If `max_size` is reached, then downscale so that the longer edge does not exceed max_size.
21
+ """
22
+
23
+ def __init__(
24
+ self, size, scale, interp=Image.BILINEAR
25
+ ):
26
+ """
27
+ """
28
+ super().__init__()
29
+ self.target_size = (size, size)
30
+ self.scale = scale
31
+ self.interp = interp
32
+
33
+ def get_transform(self, img):
34
+ # Select a random scale factor.
35
+ scale_factor = np.random.uniform(*self.scale)
36
+ scaled_target_height = scale_factor * self.target_size[0]
37
+ scaled_target_width = scale_factor * self.target_size[1]
38
+ # Recompute the accurate scale_factor using rounded scaled image size.
39
+ width, height = img.shape[1], img.shape[0]
40
+ img_scale_y = scaled_target_height / height
41
+ img_scale_x = scaled_target_width / width
42
+ img_scale = min(img_scale_y, img_scale_x)
43
+
44
+ # Select non-zero random offset (x, y) if scaled image is larger than target size
45
+ scaled_h = int(height * img_scale)
46
+ scaled_w = int(width * img_scale)
47
+ offset_y = scaled_h - self.target_size[0]
48
+ offset_x = scaled_w - self.target_size[1]
49
+ offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1))
50
+ offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1))
51
+ return EfficientDetResizeCropTransform(
52
+ scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp)
models/grit_src/grit/data/transforms/custom_transform.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3
+ # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
4
+ # Modified by Xingyi Zhou
5
+ # The original code is under Apache-2.0 License
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from fvcore.transforms.transform import (
10
+ CropTransform,
11
+ HFlipTransform,
12
+ NoOpTransform,
13
+ Transform,
14
+ TransformList,
15
+ )
16
+ from PIL import Image
17
+
18
+ try:
19
+ import cv2 # noqa
20
+ except ImportError:
21
+ # OpenCV is an optional dependency at the moment
22
+ pass
23
+
24
+ __all__ = [
25
+ "EfficientDetResizeCropTransform",
26
+ ]
27
+
28
+
29
+ class EfficientDetResizeCropTransform(Transform):
30
+ """
31
+ """
32
+
33
+ def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, \
34
+ target_size, interp=None):
35
+ """
36
+ Args:
37
+ h, w (int): original image size
38
+ new_h, new_w (int): new image size
39
+ interp: PIL interpolation methods, defaults to bilinear.
40
+ """
41
+ # TODO decide on PIL vs opencv
42
+ super().__init__()
43
+ if interp is None:
44
+ interp = Image.BILINEAR
45
+ self._set_attributes(locals())
46
+
47
+ def apply_image(self, img, interp=None):
48
+ assert len(img.shape) <= 4
49
+
50
+ if img.dtype == np.uint8:
51
+ pil_image = Image.fromarray(img)
52
+ interp_method = interp if interp is not None else self.interp
53
+ pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
54
+ ret = np.asarray(pil_image)
55
+ right = min(self.scaled_w, self.offset_x + self.target_size[1])
56
+ lower = min(self.scaled_h, self.offset_y + self.target_size[0])
57
+ if len(ret.shape) <= 3:
58
+ ret = ret[self.offset_y: lower, self.offset_x: right]
59
+ else:
60
+ ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
61
+ else:
62
+ # PIL only supports uint8
63
+ img = torch.from_numpy(img)
64
+ shape = list(img.shape)
65
+ shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
66
+ img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw
67
+ _PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"}
68
+ mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp]
69
+ img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False)
70
+ shape[:2] = (self.scaled_h, self.scaled_w)
71
+ ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c)
72
+ right = min(self.scaled_w, self.offset_x + self.target_size[1])
73
+ lower = min(self.scaled_h, self.offset_y + self.target_size[0])
74
+ if len(ret.shape) <= 3:
75
+ ret = ret[self.offset_y: lower, self.offset_x: right]
76
+ else:
77
+ ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
78
+ return ret
79
+
80
+
81
+ def apply_coords(self, coords):
82
+ coords[:, 0] = coords[:, 0] * self.img_scale
83
+ coords[:, 1] = coords[:, 1] * self.img_scale
84
+ coords[:, 0] -= self.offset_x
85
+ coords[:, 1] -= self.offset_y
86
+ return coords
87
+
88
+
89
+ def apply_segmentation(self, segmentation):
90
+ segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
91
+ return segmentation
92
+
93
+
94
+ def inverse(self):
95
+ raise NotImplementedError
96
+
97
+
98
+ def inverse_apply_coords(self, coords):
99
+ coords[:, 0] += self.offset_x
100
+ coords[:, 1] += self.offset_y
101
+ coords[:, 0] = coords[:, 0] / self.img_scale
102
+ coords[:, 1] = coords[:, 1] / self.img_scale
103
+ return coords
104
+
105
+
106
+ def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
107
+ """
108
+ """
109
+ idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
110
+ coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
111
+ coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
112
+ minxy = coords.min(axis=1)
113
+ maxxy = coords.max(axis=1)
114
+ trans_boxes = np.concatenate((minxy, maxxy), axis=1)
115
+ return trans_boxes
models/grit_src/grit/evaluation/eval.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import json
3
+ import os
4
+ from detectron2.structures import Boxes, BoxMode, pairwise_iou
5
+ from detectron2.utils.file_io import PathManager
6
+ import numpy as np
7
+ import pycocotools.mask as mask_util
8
+ from detectron2.evaluation.coco_evaluation import COCOEvaluator
9
+ from detectron2.evaluation.coco_evaluation import _evaluate_predictions_on_coco
10
+
11
+
12
+ class GRiTCOCOEvaluator(COCOEvaluator):
13
+ def process(self, inputs, outputs):
14
+ for input, output in zip(inputs, outputs):
15
+ prediction = {"image_id": input["image_id"]}
16
+
17
+ if "instances" in output:
18
+ instances = output["instances"].to(self._cpu_device)
19
+ prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
20
+
21
+ if len(prediction) > 1:
22
+ self._predictions.append(prediction)
23
+
24
+ def _eval_predictions(self, predictions, img_ids=None):
25
+ self._logger.info("Preparing results for COCO format ...")
26
+ coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
27
+ tasks = self._tasks or self._tasks_from_predictions(coco_results)
28
+
29
+ if self._output_dir:
30
+ file_path = os.path.join(self._output_dir, "coco_instances_results.json")
31
+ self._logger.info("Saving results to {}".format(file_path))
32
+ with PathManager.open(file_path, "w") as f:
33
+ f.write(json.dumps(coco_results))
34
+ f.flush()
35
+
36
+ if not self._do_evaluation:
37
+ self._logger.info("Annotations are not available for evaluation.")
38
+ return
39
+
40
+ self._logger.info(
41
+ "Evaluating predictions with {} COCO API...".format(
42
+ "unofficial" if self._use_fast_impl else "official"
43
+ )
44
+ )
45
+
46
+ coco_results = self.convert_classname_to_id(coco_results)
47
+
48
+ for task in sorted(tasks):
49
+ assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
50
+ coco_eval = (
51
+ _evaluate_predictions_on_coco(
52
+ self._coco_api,
53
+ coco_results,
54
+ task,
55
+ kpt_oks_sigmas=self._kpt_oks_sigmas,
56
+ use_fast_impl=self._use_fast_impl,
57
+ img_ids=img_ids,
58
+ max_dets_per_image=self._max_dets_per_image,
59
+ )
60
+ if len(coco_results) > 0
61
+ else None # cocoapi does not handle empty results very well
62
+ )
63
+
64
+ res = self._derive_coco_results(
65
+ coco_eval, task, class_names=self._metadata.get("thing_classes")
66
+ )
67
+ self._results[task] = res
68
+
69
+ def convert_classname_to_id(self, results):
70
+ outputs = []
71
+ class_name_to_id = {}
72
+ categories = sorted(self._coco_api.dataset['categories'], key=lambda x: x['id'])
73
+
74
+ for cat in categories:
75
+ class_name_to_id[cat['name']] = cat['id']
76
+
77
+ for pred in results:
78
+ if pred['object_descriptions'] in class_name_to_id:
79
+ pred['category_id'] = class_name_to_id[pred['object_descriptions']]
80
+ del pred['object_descriptions']
81
+ outputs.append(pred)
82
+
83
+ return outputs
84
+
85
+
86
+ class GRiTVGEvaluator(COCOEvaluator):
87
+ def process(self, inputs, outputs):
88
+ for input, output in zip(inputs, outputs):
89
+ assert input["image_id"] == int(input['file_name'].split('/')[-1].split('.')[0])
90
+ prediction = {"image_id": input["image_id"]}
91
+
92
+ if "instances" in output:
93
+ instances = output["instances"].to(self._cpu_device)
94
+ prediction["instances"] = instances_to_coco_json(instances, input["image_id"], output_logits=True)
95
+ h = input['height']
96
+ w = input['width']
97
+ scale = 720.0 / max(h, w)
98
+ scaled_inst = []
99
+ for inst in prediction["instances"]:
100
+ inst['bbox'][0] = inst['bbox'][0] * scale
101
+ inst['bbox'][1] = inst['bbox'][1] * scale
102
+ inst['bbox'][2] = inst['bbox'][2] * scale
103
+ inst['bbox'][3] = inst['bbox'][3] * scale
104
+ scaled_inst.append(inst)
105
+ if len(scaled_inst) > 0:
106
+ prediction["instances"] = scaled_inst
107
+ if len(prediction) > 1:
108
+ self._predictions.append(prediction)
109
+
110
+ def _eval_predictions(self, predictions, img_ids=None):
111
+ '''
112
+ This is only for saving the results to json file
113
+ '''
114
+ self._logger.info("Preparing results for COCO format ...")
115
+ coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
116
+
117
+ if self._output_dir:
118
+ file_path = os.path.join(self._output_dir, "vg_instances_results.json")
119
+ self._logger.info("Saving results to {}".format(file_path))
120
+ with PathManager.open(file_path, "w") as f:
121
+ f.write(json.dumps(coco_results))
122
+ f.flush()
123
+
124
+
125
+ def instances_to_coco_json(instances, img_id, output_logits=False):
126
+ """
127
+ Add object_descriptions and logit (if applicable) to
128
+ detectron2's instances_to_coco_json
129
+ """
130
+ num_instance = len(instances)
131
+ if num_instance == 0:
132
+ return []
133
+
134
+ boxes = instances.pred_boxes.tensor.numpy()
135
+ boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS)
136
+ boxes = boxes.tolist()
137
+ scores = instances.scores.tolist()
138
+ classes = instances.pred_classes.tolist()
139
+ object_descriptions = instances.pred_object_descriptions.data
140
+ if output_logits:
141
+ logits = instances.logits.tolist()
142
+
143
+ results = []
144
+ for k in range(num_instance):
145
+ result = {
146
+ "image_id": img_id,
147
+ "category_id": classes[k],
148
+ "bbox": boxes[k],
149
+ "score": scores[k],
150
+ 'object_descriptions': object_descriptions[k],
151
+ }
152
+ if output_logits:
153
+ result["logit"] = logits[k]
154
+
155
+ results.append(result)
156
+ return results
models/grit_src/grit/modeling/__pycache__/soft_nms.cpython-38.pyc ADDED
Binary file (5.99 kB). View file
 
models/grit_src/grit/modeling/backbone/__pycache__/utils.cpython-38.pyc ADDED
Binary file (6.12 kB). View file
 
models/grit_src/grit/modeling/backbone/__pycache__/vit.cpython-38.pyc ADDED
Binary file (15.6 kB). View file
 
models/grit_src/grit/modeling/backbone/utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ # This code is from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/utils.py
3
+ import math
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ __all__ = [
9
+ "window_partition",
10
+ "window_unpartition",
11
+ "add_decomposed_rel_pos",
12
+ "get_abs_pos",
13
+ "PatchEmbed",
14
+ ]
15
+
16
+ def window_partition(x, window_size):
17
+ """
18
+ Partition into non-overlapping windows with padding if needed.
19
+ Args:
20
+ x (tensor): input tokens with [B, H, W, C].
21
+ window_size (int): window size.
22
+
23
+ Returns:
24
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
25
+ (Hp, Wp): padded height and width before partition
26
+ """
27
+ B, H, W, C = x.shape
28
+
29
+ pad_h = (window_size - H % window_size) % window_size
30
+ pad_w = (window_size - W % window_size) % window_size
31
+ if pad_h > 0 or pad_w > 0:
32
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
33
+ Hp, Wp = H + pad_h, W + pad_w
34
+
35
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
36
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37
+ return windows, (Hp, Wp)
38
+
39
+
40
+ def window_unpartition(windows, window_size, pad_hw, hw):
41
+ """
42
+ Window unpartition into original sequences and removing padding.
43
+ Args:
44
+ x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
45
+ window_size (int): window size.
46
+ pad_hw (Tuple): padded height and width (Hp, Wp).
47
+ hw (Tuple): original height and width (H, W) before padding.
48
+
49
+ Returns:
50
+ x: unpartitioned sequences with [B, H, W, C].
51
+ """
52
+ Hp, Wp = pad_hw
53
+ H, W = hw
54
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55
+ x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
56
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
57
+
58
+ if Hp > H or Wp > W:
59
+ x = x[:, :H, :W, :].contiguous()
60
+ return x
61
+
62
+
63
+ def get_rel_pos(q_size, k_size, rel_pos):
64
+ """
65
+ Get relative positional embeddings according to the relative positions of
66
+ query and key sizes.
67
+ Args:
68
+ q_size (int): size of query q.
69
+ k_size (int): size of key k.
70
+ rel_pos (Tensor): relative position embeddings (L, C).
71
+
72
+ Returns:
73
+ Extracted positional embeddings according to relative positions.
74
+ """
75
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
76
+ # Interpolate rel pos if needed.
77
+ if rel_pos.shape[0] != max_rel_dist:
78
+ # Interpolate rel pos.
79
+ rel_pos_resized = F.interpolate(
80
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
81
+ size=max_rel_dist,
82
+ mode="linear",
83
+ )
84
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
85
+ else:
86
+ rel_pos_resized = rel_pos
87
+
88
+ # Scale the coords with short length if shapes for q and k are different.
89
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
90
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
91
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
92
+
93
+ return rel_pos_resized[relative_coords.long()]
94
+
95
+
96
+ def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
97
+ """
98
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
99
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
100
+ Args:
101
+ attn (Tensor): attention map.
102
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
103
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
104
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
105
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
106
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
107
+
108
+ Returns:
109
+ attn (Tensor): attention map with added relative positional embeddings.
110
+ """
111
+ q_h, q_w = q_size
112
+ k_h, k_w = k_size
113
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
114
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
115
+
116
+ B, _, dim = q.shape
117
+ r_q = q.reshape(B, q_h, q_w, dim)
118
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
119
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
120
+
121
+ attn = (
122
+ attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
123
+ ).view(B, q_h * q_w, k_h * k_w)
124
+
125
+ return attn
126
+
127
+
128
+ def get_abs_pos(abs_pos, has_cls_token, hw):
129
+ """
130
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
131
+ dimension for the original embeddings.
132
+ Args:
133
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
134
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
135
+ hw (Tuple): size of input image tokens.
136
+
137
+ Returns:
138
+ Absolute positional embeddings after processing with shape (1, H, W, C)
139
+ """
140
+ h, w = hw
141
+ if has_cls_token:
142
+ abs_pos = abs_pos[:, 1:]
143
+ xy_num = abs_pos.shape[1]
144
+ size = int(math.sqrt(xy_num))
145
+ assert size * size == xy_num
146
+
147
+ if size != h or size != w:
148
+ new_abs_pos = F.interpolate(
149
+ abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
150
+ size=(h, w),
151
+ mode="bicubic",
152
+ align_corners=False,
153
+ )
154
+
155
+ return new_abs_pos.permute(0, 2, 3, 1)
156
+ else:
157
+ return abs_pos.reshape(1, h, w, -1)
158
+
159
+
160
+ class PatchEmbed(nn.Module):
161
+ """
162
+ Image to Patch Embedding.
163
+ """
164
+
165
+ def __init__(
166
+ self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
167
+ ):
168
+ """
169
+ Args:
170
+ kernel_size (Tuple): kernel size of the projection layer.
171
+ stride (Tuple): stride of the projection layer.
172
+ padding (Tuple): padding size of the projection layer.
173
+ in_chans (int): Number of input image channels.
174
+ embed_dim (int): embed_dim (int): Patch embedding dimension.
175
+ """
176
+ super().__init__()
177
+
178
+ self.proj = nn.Conv2d(
179
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
180
+ )
181
+
182
+ def forward(self, x):
183
+ x = self.proj(x)
184
+ # B C H W -> B H W C
185
+ x = x.permute(0, 2, 3, 1)
186
+ return x
models/grit_src/grit/modeling/backbone/vit.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified by Jialian Wu from https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
2
+ import logging
3
+ import math
4
+ import fvcore.nn.weight_init as weight_init
5
+ import torch
6
+ import torch.nn as nn
7
+ from functools import partial
8
+
9
+ from detectron2.layers import CNNBlockBase, Conv2d, get_norm
10
+ from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
11
+ from detectron2.layers import ShapeSpec
12
+ from centernet.modeling.backbone.fpn_p5 import LastLevelP6P7_P5
13
+
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import DropPath, Mlp, trunc_normal_
16
+
17
+ from detectron2.modeling.backbone.backbone import Backbone
18
+ from .utils import (
19
+ PatchEmbed,
20
+ add_decomposed_rel_pos,
21
+ get_abs_pos,
22
+ window_partition,
23
+ window_unpartition,
24
+ )
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ __all__ = ["ViT"]
30
+
31
+
32
+ class Attention(nn.Module):
33
+ """Multi-head Attention block with relative position embeddings."""
34
+
35
+ def __init__(
36
+ self,
37
+ dim,
38
+ num_heads=8,
39
+ qkv_bias=True,
40
+ use_rel_pos=False,
41
+ rel_pos_zero_init=True,
42
+ input_size=None,
43
+ ):
44
+ """
45
+ Args:
46
+ dim (int): Number of input channels.
47
+ num_heads (int): Number of attention heads.
48
+ qkv_bias (bool: If True, add a learnable bias to query, key, value.
49
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
50
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
51
+ input_size (int or None): Input resolution for calculating the relative positional
52
+ parameter size.
53
+ """
54
+ super().__init__()
55
+ self.num_heads = num_heads
56
+ head_dim = dim // num_heads
57
+ self.scale = head_dim**-0.5
58
+
59
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
60
+ self.proj = nn.Linear(dim, dim)
61
+
62
+ self.use_rel_pos = use_rel_pos
63
+ if self.use_rel_pos:
64
+ # initialize relative positional embeddings
65
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
66
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
67
+
68
+ if not rel_pos_zero_init:
69
+ trunc_normal_(self.rel_pos_h, std=0.02)
70
+ trunc_normal_(self.rel_pos_w, std=0.02)
71
+
72
+ def forward(self, x):
73
+ B, H, W, _ = x.shape
74
+ # qkv with shape (3, B, nHead, H * W, C)
75
+ qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
76
+ # q, k, v with shape (B * nHead, H * W, C)
77
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
78
+
79
+ attn = (q * self.scale) @ k.transpose(-2, -1)
80
+
81
+ if self.use_rel_pos:
82
+ attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
83
+
84
+ attn = attn.softmax(dim=-1)
85
+ x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
86
+ x = self.proj(x)
87
+
88
+ return x
89
+
90
+
91
+ class ResBottleneckBlock(CNNBlockBase):
92
+ """
93
+ The standard bottleneck residual block without the last activation layer.
94
+ It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ in_channels,
100
+ out_channels,
101
+ bottleneck_channels,
102
+ norm="LN",
103
+ act_layer=nn.GELU,
104
+ ):
105
+ """
106
+ Args:
107
+ in_channels (int): Number of input channels.
108
+ out_channels (int): Number of output channels.
109
+ bottleneck_channels (int): number of output channels for the 3x3
110
+ "bottleneck" conv layers.
111
+ norm (str or callable): normalization for all conv layers.
112
+ See :func:`layers.get_norm` for supported format.
113
+ act_layer (callable): activation for all conv layers.
114
+ """
115
+ super().__init__(in_channels, out_channels, 1)
116
+
117
+ self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
118
+ self.norm1 = get_norm(norm, bottleneck_channels)
119
+ self.act1 = act_layer()
120
+
121
+ self.conv2 = Conv2d(
122
+ bottleneck_channels,
123
+ bottleneck_channels,
124
+ 3,
125
+ padding=1,
126
+ bias=False,
127
+ )
128
+ self.norm2 = get_norm(norm, bottleneck_channels)
129
+ self.act2 = act_layer()
130
+
131
+ self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
132
+ self.norm3 = get_norm(norm, out_channels)
133
+
134
+ for layer in [self.conv1, self.conv2, self.conv3]:
135
+ weight_init.c2_msra_fill(layer)
136
+ for layer in [self.norm1, self.norm2]:
137
+ layer.weight.data.fill_(1.0)
138
+ layer.bias.data.zero_()
139
+ # zero init last norm layer.
140
+ self.norm3.weight.data.zero_()
141
+ self.norm3.bias.data.zero_()
142
+
143
+ def forward(self, x):
144
+ out = x
145
+ for layer in self.children():
146
+ out = layer(out)
147
+
148
+ out = x + out
149
+ return out
150
+
151
+
152
+ class Block(nn.Module):
153
+ """Transformer blocks with support of window attention and residual propagation blocks"""
154
+
155
+ def __init__(
156
+ self,
157
+ dim,
158
+ num_heads,
159
+ mlp_ratio=4.0,
160
+ qkv_bias=True,
161
+ drop_path=0.0,
162
+ norm_layer=nn.LayerNorm,
163
+ act_layer=nn.GELU,
164
+ use_rel_pos=False,
165
+ rel_pos_zero_init=True,
166
+ window_size=0,
167
+ use_residual_block=False,
168
+ input_size=None,
169
+ ):
170
+ """
171
+ Args:
172
+ dim (int): Number of input channels.
173
+ num_heads (int): Number of attention heads in each ViT block.
174
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
175
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
176
+ drop_path (float): Stochastic depth rate.
177
+ norm_layer (nn.Module): Normalization layer.
178
+ act_layer (nn.Module): Activation layer.
179
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
180
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
181
+ window_size (int): Window size for window attention blocks. If it equals 0, then not
182
+ use window attention.
183
+ use_residual_block (bool): If True, use a residual block after the MLP block.
184
+ input_size (int or None): Input resolution for calculating the relative positional
185
+ parameter size.
186
+ """
187
+ super().__init__()
188
+ self.norm1 = norm_layer(dim)
189
+ self.attn = Attention(
190
+ dim,
191
+ num_heads=num_heads,
192
+ qkv_bias=qkv_bias,
193
+ use_rel_pos=use_rel_pos,
194
+ rel_pos_zero_init=rel_pos_zero_init,
195
+ input_size=input_size if window_size == 0 else (window_size, window_size),
196
+ )
197
+
198
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
199
+ self.norm2 = norm_layer(dim)
200
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
201
+
202
+ self.window_size = window_size
203
+
204
+ self.use_residual_block = use_residual_block
205
+ if use_residual_block:
206
+ # Use a residual block with bottleneck channel as dim // 2
207
+ self.residual = ResBottleneckBlock(
208
+ in_channels=dim,
209
+ out_channels=dim,
210
+ bottleneck_channels=dim // 2,
211
+ norm="LN",
212
+ act_layer=act_layer,
213
+ )
214
+
215
+ def forward(self, x):
216
+ shortcut = x
217
+ x = self.norm1(x)
218
+ # Window partition
219
+ if self.window_size > 0:
220
+ H, W = x.shape[1], x.shape[2]
221
+ x, pad_hw = window_partition(x, self.window_size)
222
+
223
+ x = self.attn(x)
224
+ # Reverse window partition
225
+ if self.window_size > 0:
226
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
227
+
228
+ x = shortcut + self.drop_path(x)
229
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
230
+
231
+ if self.use_residual_block:
232
+ x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
233
+
234
+ return x
235
+
236
+
237
+ class ViT(Backbone):
238
+ """
239
+ This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
240
+ "Exploring Plain Vision Transformer Backbones for Object Detection",
241
+ https://arxiv.org/abs/2203.16527
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ img_size=1024,
247
+ patch_size=16,
248
+ in_chans=3,
249
+ embed_dim=768,
250
+ depth=12,
251
+ num_heads=12,
252
+ mlp_ratio=4.0,
253
+ qkv_bias=True,
254
+ drop_path_rate=0.0,
255
+ norm_layer=nn.LayerNorm,
256
+ act_layer=nn.GELU,
257
+ use_abs_pos=True,
258
+ use_rel_pos=False,
259
+ rel_pos_zero_init=True,
260
+ window_size=0,
261
+ window_block_indexes=(),
262
+ residual_block_indexes=(),
263
+ use_act_checkpoint=True,
264
+ pretrain_img_size=224,
265
+ pretrain_use_cls_token=True,
266
+ out_feature="last_feat",
267
+ ):
268
+ """
269
+ Args:
270
+ img_size (int): Input image size.
271
+ patch_size (int): Patch size.
272
+ in_chans (int): Number of input image channels.
273
+ embed_dim (int): Patch embedding dimension.
274
+ depth (int): Depth of ViT.
275
+ num_heads (int): Number of attention heads in each ViT block.
276
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
277
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
278
+ drop_path_rate (float): Stochastic depth rate.
279
+ norm_layer (nn.Module): Normalization layer.
280
+ act_layer (nn.Module): Activation layer.
281
+ use_abs_pos (bool): If True, use absolute positional embeddings.
282
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
283
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
284
+ window_size (int): Window size for window attention blocks.
285
+ window_block_indexes (list): Indexes for blocks using window attention.
286
+ residual_block_indexes (list): Indexes for blocks using conv propagation.
287
+ use_act_checkpoint (bool): If True, use activation checkpointing.
288
+ pretrain_img_size (int): input image size for pretraining models.
289
+ pretrain_use_cls_token (bool): If True, pretrainig models use class token.
290
+ out_feature (str): name of the feature from the last block.
291
+ """
292
+ super().__init__()
293
+ self.pretrain_use_cls_token = pretrain_use_cls_token
294
+ self.use_act_checkpoint = use_act_checkpoint
295
+
296
+ self.patch_embed = PatchEmbed(
297
+ kernel_size=(patch_size, patch_size),
298
+ stride=(patch_size, patch_size),
299
+ in_chans=in_chans,
300
+ embed_dim=embed_dim,
301
+ )
302
+
303
+ if use_abs_pos:
304
+ # Initialize absolute positional embedding with pretrain image size.
305
+ num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
306
+ num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
307
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
308
+ else:
309
+ self.pos_embed = None
310
+
311
+ # stochastic depth decay rule
312
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
313
+
314
+ self.blocks = nn.ModuleList()
315
+ for i in range(depth):
316
+ block = Block(
317
+ dim=embed_dim,
318
+ num_heads=num_heads,
319
+ mlp_ratio=mlp_ratio,
320
+ qkv_bias=qkv_bias,
321
+ drop_path=dpr[i],
322
+ norm_layer=norm_layer,
323
+ act_layer=act_layer,
324
+ use_rel_pos=use_rel_pos,
325
+ rel_pos_zero_init=rel_pos_zero_init,
326
+ window_size=window_size if i in window_block_indexes else 0,
327
+ use_residual_block=i in residual_block_indexes,
328
+ input_size=(img_size // patch_size, img_size // patch_size),
329
+ )
330
+ self.blocks.append(block)
331
+
332
+ self._out_feature_channels = {out_feature: embed_dim}
333
+ self._out_feature_strides = {out_feature: patch_size}
334
+ self._out_features = [out_feature]
335
+
336
+ if self.pos_embed is not None:
337
+ trunc_normal_(self.pos_embed, std=0.02)
338
+
339
+ self.apply(self._init_weights)
340
+
341
+ def _init_weights(self, m):
342
+ if isinstance(m, nn.Linear):
343
+ trunc_normal_(m.weight, std=0.02)
344
+ if isinstance(m, nn.Linear) and m.bias is not None:
345
+ nn.init.constant_(m.bias, 0)
346
+ elif isinstance(m, nn.LayerNorm):
347
+ nn.init.constant_(m.bias, 0)
348
+ nn.init.constant_(m.weight, 1.0)
349
+
350
+ def forward(self, x):
351
+ x = self.patch_embed(x)
352
+ if self.pos_embed is not None:
353
+ x = x + get_abs_pos(
354
+ self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
355
+ )
356
+
357
+ for blk in self.blocks:
358
+ if self.use_act_checkpoint:
359
+ x = checkpoint.checkpoint(blk, x)
360
+ else:
361
+ x = blk(x)
362
+
363
+ return x.permute(0, 3, 1, 2)
364
+
365
+
366
+ class ViT_FPN(Backbone):
367
+ def __init__(self, bottom_up=None, top_block=None, out_channels=None, strides=None, vit_out_dim=None):
368
+ super(ViT_FPN, self).__init__()
369
+ assert isinstance(bottom_up, Backbone)
370
+ self.bottom_up = bottom_up
371
+ self.top_block = top_block
372
+
373
+ self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
374
+ self._out_features = list(self._out_feature_strides.keys())
375
+ self._out_feature_channels = {k: out_channels for k in self._out_features}
376
+ self._size_divisibility = strides[2]
377
+
378
+ self.maxpool = nn.MaxPool2d(2, stride=2)
379
+ self.fpn_stride_16_8 = nn.ConvTranspose2d(vit_out_dim, vit_out_dim, 2, stride=2, bias=False)
380
+ self.fpn_stride8_conv1 = nn.Conv2d(in_channels=vit_out_dim, out_channels=out_channels, kernel_size=1, bias=False)
381
+ self.fpn_stride8_norm1 = nn.LayerNorm(out_channels)
382
+ self.fpn_stride8_conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
383
+ self.fpn_stride8_norm2 = nn.LayerNorm(out_channels)
384
+
385
+ self.fpn_stride16_conv1 = nn.Conv2d(in_channels=vit_out_dim, out_channels=out_channels, kernel_size=1, bias=False)
386
+ self.fpn_stride16_norm1 = nn.LayerNorm(out_channels)
387
+ self.fpn_stride16_conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
388
+ self.fpn_stride16_norm2 = nn.LayerNorm(out_channels)
389
+
390
+ self.fpn_stride32_conv1 = nn.Conv2d(in_channels=vit_out_dim, out_channels=out_channels, kernel_size=1, bias=False)
391
+ self.fpn_stride32_norm1 = nn.LayerNorm(out_channels)
392
+ self.fpn_stride32_conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
393
+ self.fpn_stride32_norm2 = nn.LayerNorm(out_channels)
394
+
395
+ def forward(self, x):
396
+ vit_output_featuremap = self.bottom_up(x)
397
+
398
+ stride8_feature = self.fpn_stride_16_8(vit_output_featuremap)
399
+ stride8_feature = self.fpn_stride8_norm1(self.fpn_stride8_conv1(stride8_feature)
400
+ .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
401
+ stride8_feature = self.fpn_stride8_norm2(self.fpn_stride8_conv2(stride8_feature)
402
+ .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
403
+
404
+ stride32_feature = self.maxpool(vit_output_featuremap)
405
+ stride32_feature = self.fpn_stride32_norm1(self.fpn_stride32_conv1(stride32_feature)
406
+ .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
407
+ stride32_feature = self.fpn_stride32_norm2(self.fpn_stride32_conv2(stride32_feature)
408
+ .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
409
+
410
+ stride16_feature = self.fpn_stride16_norm1(self.fpn_stride16_conv1(vit_output_featuremap).
411
+ permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
412
+ stride16_feature = self.fpn_stride16_norm2(self.fpn_stride16_conv2(stride16_feature)
413
+ .permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
414
+
415
+ results = [stride8_feature, stride16_feature, stride32_feature]
416
+
417
+ results.extend(self.top_block(stride32_feature))
418
+
419
+ assert len(self._out_features) == len(results)
420
+ fpn_out = {f: res for f, res in zip(self._out_features, results)}
421
+
422
+ return fpn_out
423
+ @property
424
+ def size_divisibility(self):
425
+ return self._size_divisibility
426
+
427
+ def output_shape(self):
428
+ return {
429
+ name: ShapeSpec(
430
+ channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
431
+ )
432
+ for name in self._out_features
433
+ }
434
+
435
+
436
+ @BACKBONE_REGISTRY.register()
437
+ def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
438
+ embed_dim = 768
439
+ vit_out_dim = embed_dim
440
+ bottom_up = ViT( # Single-scale ViT backbone
441
+ img_size=1024,
442
+ patch_size=16,
443
+ embed_dim=embed_dim,
444
+ depth=12,
445
+ num_heads=12,
446
+ drop_path_rate=0.1,
447
+ window_size=14,
448
+ mlp_ratio=4,
449
+ qkv_bias=True,
450
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
451
+ window_block_indexes=[
452
+ # 2, 5, 8 11 for global attention
453
+ 0,
454
+ 1,
455
+ 3,
456
+ 4,
457
+ 6,
458
+ 7,
459
+ 9,
460
+ 10,
461
+ ],
462
+ residual_block_indexes=[],
463
+ use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
464
+ use_rel_pos=True,
465
+ out_feature="last_feat",)
466
+
467
+ out_channels = cfg.MODEL.FPN.OUT_CHANNELS
468
+ assert out_channels == 256 or out_channels == 768 or out_channels == 1024
469
+ backbone = ViT_FPN(bottom_up=bottom_up,
470
+ top_block=LastLevelP6P7_P5(out_channels, out_channels),
471
+ out_channels=out_channels,
472
+ strides=[8, 16, 32, 64, 128],
473
+ vit_out_dim=vit_out_dim)
474
+ return backbone
475
+
476
+
477
+ @BACKBONE_REGISTRY.register()
478
+ def build_vit_fpn_backbone_large(cfg, input_shape: ShapeSpec):
479
+ window_block_indexes = (list(range(0, 5)) + list(range(6, 11)) + list(range(12, 17)) + list(range(18, 23)))
480
+ embed_dim = 1024
481
+ vit_out_dim = embed_dim
482
+ bottom_up = ViT( # Single-scale ViT backbone
483
+ img_size=1024,
484
+ patch_size=16,
485
+ embed_dim=embed_dim,
486
+ depth=24,
487
+ num_heads=16,
488
+ drop_path_rate=0.4,
489
+ window_size=14,
490
+ mlp_ratio=4,
491
+ qkv_bias=True,
492
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
493
+ window_block_indexes=window_block_indexes,
494
+ residual_block_indexes=[],
495
+ use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
496
+ use_rel_pos=True,
497
+ out_feature="last_feat",)
498
+
499
+ out_channels = cfg.MODEL.FPN.OUT_CHANNELS
500
+ assert out_channels == 256 or out_channels == 768 or out_channels == 1024
501
+ backbone = ViT_FPN(bottom_up=bottom_up,
502
+ top_block=LastLevelP6P7_P5(out_channels, out_channels),
503
+ out_channels=out_channels,
504
+ strides=[8, 16, 32, 64, 128],
505
+ vit_out_dim=vit_out_dim)
506
+ return backbone
507
+
508
+
509
+ @BACKBONE_REGISTRY.register()
510
+ def build_vit_fpn_backbone_huge(cfg, input_shape: ShapeSpec):
511
+ window_block_indexes = (list(range(0, 7)) + list(range(8, 15)) + list(range(16, 23)) + list(range(24, 31)))
512
+ embed_dim = 1280
513
+ vit_out_dim = embed_dim
514
+ bottom_up = ViT( # Single-scale ViT backbone
515
+ img_size=1024,
516
+ patch_size=16,
517
+ embed_dim=embed_dim,
518
+ depth=32,
519
+ num_heads=16,
520
+ drop_path_rate=0.5,
521
+ window_size=14,
522
+ mlp_ratio=4,
523
+ qkv_bias=True,
524
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
525
+ window_block_indexes=window_block_indexes,
526
+ residual_block_indexes=[],
527
+ use_act_checkpoint=cfg.USE_ACT_CHECKPOINT,
528
+ use_rel_pos=True,
529
+ out_feature="last_feat",)
530
+
531
+ out_channels = cfg.MODEL.FPN.OUT_CHANNELS
532
+ assert out_channels == 256 or out_channels == 768 or out_channels == 1024
533
+ backbone = ViT_FPN(bottom_up=bottom_up,
534
+ top_block=LastLevelP6P7_P5(out_channels, out_channels),
535
+ out_channels=out_channels,
536
+ strides=[8, 16, 32, 64, 128],
537
+ vit_out_dim=vit_out_dim)
538
+ return backbone
models/grit_src/grit/modeling/meta_arch/__pycache__/grit.cpython-38.pyc ADDED
Binary file (2.49 kB). View file
 
models/grit_src/grit/modeling/meta_arch/grit.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple
2
+ import torch
3
+ from detectron2.config import configurable
4
+ from detectron2.structures import ImageList, Instances, Boxes
5
+ from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
6
+ from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
7
+
8
+
9
+ @META_ARCH_REGISTRY.register()
10
+ class GRiT(GeneralizedRCNN):
11
+ @configurable
12
+ def __init__(
13
+ self,
14
+ **kwargs):
15
+ super().__init__(**kwargs)
16
+ assert self.proposal_generator is not None
17
+
18
+ @classmethod
19
+ def from_config(cls, cfg):
20
+ ret = super().from_config(cfg)
21
+ return ret
22
+
23
+ def inference(
24
+ self,
25
+ batched_inputs: Tuple[Dict[str, torch.Tensor]],
26
+ detected_instances: Optional[List[Instances]] = None,
27
+ do_postprocess: bool = True,
28
+ ):
29
+ assert not self.training
30
+ assert detected_instances is None
31
+
32
+ images = self.preprocess_image(batched_inputs)
33
+ features = self.backbone(images.tensor)
34
+ proposals, _ = self.proposal_generator(images, features, None)
35
+ results, _ = self.roi_heads(features, proposals)
36
+ if do_postprocess:
37
+ assert not torch.jit.is_scripting(), \
38
+ "Scripting is not supported for postprocess."
39
+ return GRiT._postprocess(
40
+ results, batched_inputs, images.image_sizes)
41
+ else:
42
+ return results
43
+
44
+ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
45
+ if not self.training:
46
+ return self.inference(batched_inputs)
47
+
48
+ images = self.preprocess_image(batched_inputs)
49
+
50
+ gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
51
+
52
+ targets_task = batched_inputs[0]['task']
53
+ for anno_per_image in batched_inputs:
54
+ assert targets_task == anno_per_image['task']
55
+
56
+ features = self.backbone(images.tensor)
57
+ proposals, proposal_losses = self.proposal_generator(
58
+ images, features, gt_instances)
59
+ proposals, roihead_textdecoder_losses = self.roi_heads(
60
+ features, proposals, gt_instances, targets_task=targets_task)
61
+
62
+ losses = {}
63
+ losses.update(roihead_textdecoder_losses)
64
+ losses.update(proposal_losses)
65
+
66
+ return losses