import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table

from models.beit.st_beit import BeitConfig, BeitModel
from models.temporal_model import (STAdapter, TemporalAttention,
                                   WindowTemporalAttention)


def mem_stat():
    """Print the peak GPU memory allocated so far, in MB."""
    mem = torch.cuda.max_memory_allocated() / 1024 / 1024
    print(f"max memory allocated: {mem:.1f}MB")


def build_backbone(tm_block="timesformer"):
    """Build a BEiT backbone with the given temporal-modeling block.

    Args:
        tm_block: one of "timesformer", "st_adapter", "xclip", "none",
            "wa_2x2", or "wa_7x7".

    Returns:
        A BeitModel configured with the requested temporal block.
    """
    if tm_block in ("timesformer", "st_adapter", "xclip", "none"):
        other_cfg = dict(
            num_frames=12, temporal_model_block=tm_block, temporal_model_config={}
        )
    elif tm_block == "wa_2x2":
        other_cfg = dict(
            num_frames=12,
            temporal_model_block="window_attention",
            temporal_model_config=dict(window_size=(2, 2)),
        )
    elif tm_block == "wa_7x7":
        other_cfg = dict(
            num_frames=12,
            temporal_model_block="window_attention",
            temporal_model_config=dict(window_size=(7, 7)),
        )
    else:
        raise ValueError(f"unknown temporal block: {tm_block!r}")

    model_card = "microsoft/beit-base-patch16-224-pt22k-ft22k"
    model_config = BeitConfig.from_pretrained(model_card, image_size=224, **other_cfg)
    model = BeitModel(model_config)
    return model


# model = TemporalAttention()
model = build_backbone("st_adapter")
model.gradient_checkpointing_enable()
model.cuda()

for i in range(3):
    # Batch of 32 clips, 12 frames each, at 224x224 resolution.
    x = torch.rand(32, 12, 3, 224, 224, device="cuda", requires_grad=True)
    y = model(x)
    loss = y[0].mean()
    loss.backward()
    mem_stat()

# flops = FlopCountAnalysis(model, x)
# print(flop_count_table(flops))
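
# Sketch: the loop above prints a running maximum, because
# torch.cuda.max_memory_allocated() reports the peak since program start (or
# the last reset). Resetting the peak counter first isolates a single
# iteration's footprint.
torch.cuda.reset_peak_memory_stats()
x = torch.rand(32, 12, 3, 224, 224, device="cuda", requires_grad=True)
model(x)[0].mean().backward()
mem_stat()  # peak for this single forward/backward pass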