mtgv
/

Image Classification
mtgv commited on
Commit
52b75dc
1 Parent(s): ed5a071

add ep800 ckpt

Browse files
epoch_800.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ff99b4495a8f6d96bcfef0a38737792748a98353f2fb90edaef5d9508bee6e4a
3
+ size 3404723238
mae_lama-base-p16_8xb512-amp-coslr-800e_in1k.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ auto_scale_lr = dict(base_batch_size=4096)
2
+ data_preprocessor = dict(
3
+ mean=[
4
+ 123.675,
5
+ 116.28,
6
+ 103.53,
7
+ ],
8
+ non_blocking=True,
9
+ std=[
10
+ 58.395,
11
+ 57.12,
12
+ 57.375,
13
+ ],
14
+ to_rgb=True,
15
+ type='SelfSupDataPreprocessor')
16
+ data_root = '/workdir/ILSVRC2012/'
17
+ dataset_type = 'ImageNet'
18
+ default_hooks = dict(
19
+ checkpoint=dict(interval=1, max_keep_ckpts=3, type='CheckpointHook'),
20
+ logger=dict(interval=20, type='LoggerHook'),
21
+ param_scheduler=dict(type='ParamSchedulerHook'),
22
+ sampler_seed=dict(type='DistSamplerSeedHook'),
23
+ timer=dict(type='IterTimerHook'),
24
+ visualization=dict(enable=False, type='VisualizationHook'))
25
+ default_scope = 'mmpretrain'
26
+ env_cfg = dict(
27
+ cudnn_benchmark=True,
28
+ dist_cfg=dict(backend='nccl'),
29
+ mp_cfg=dict(mp_start_method='spawn', opencv_num_threads=0))
30
+ launcher = 'pytorch'
31
+ load_from = None
32
+ log_level = 'INFO'
33
+ model = dict(
34
+ backbone=dict(arch='b', mask_ratio=0.75, patch_size=16, type='MAELLaMA'),
35
+ head=dict(
36
+ loss=dict(criterion='L2', type='PixelReconstructionLoss'),
37
+ norm_pix=True,
38
+ patch_size=16,
39
+ type='MAEPretrainHead'),
40
+ init_cfg=[
41
+ dict(distribution='uniform', layer='Linear', type='Xavier'),
42
+ dict(bias=0.0, layer='LayerNorm', type='Constant', val=1.0),
43
+ ],
44
+ neck=dict(
45
+ decoder_depth=8,
46
+ decoder_embed_dim=512,
47
+ decoder_num_heads=16,
48
+ embed_dim=768,
49
+ in_chans=3,
50
+ mlp_ratio=4.0,
51
+ patch_size=16,
52
+ type='MAEPretrainDecoder'),
53
+ type='MAE')
54
+ optim_wrapper = dict(
55
+ loss_scale='dynamic',
56
+ optimizer=dict(
57
+ betas=(
58
+ 0.9,
59
+ 0.95,
60
+ ), lr=0.0024, type='AdamW', weight_decay=0.05),
61
+ paramwise_cfg=dict(
62
+ custom_keys=dict(
63
+ bias=dict(decay_mult=0.0),
64
+ cls_token=dict(decay_mult=0.0),
65
+ ln=dict(decay_mult=0.0),
66
+ mask_token=dict(decay_mult=0.0),
67
+ pos_embed=dict(decay_mult=0.0))),
68
+ type='AmpOptimWrapper')
69
+ param_scheduler = [
70
+ dict(
71
+ begin=0,
72
+ by_epoch=True,
73
+ convert_to_iter_based=True,
74
+ end=40,
75
+ start_factor=1e-09,
76
+ type='LinearLR'),
77
+ dict(
78
+ T_max=760,
79
+ begin=40,
80
+ by_epoch=True,
81
+ convert_to_iter_based=True,
82
+ end=800,
83
+ type='CosineAnnealingLR'),
84
+ ]
85
+ randomness = dict(deterministic=False, diff_rank_seed=True, seed=0)
86
+ resume = True
87
+ train_cfg = dict(max_epochs=800, type='EpochBasedTrainLoop')
88
+ train_dataloader = dict(
89
+ batch_size=512,
90
+ collate_fn=dict(type='default_collate'),
91
+ dataset=dict(
92
+ data_root='/workdir/ILSVRC2012/',
93
+ pipeline=[
94
+ dict(type='LoadImageFromFile'),
95
+ dict(
96
+ backend='pillow',
97
+ crop_ratio_range=(
98
+ 0.2,
99
+ 1.0,
100
+ ),
101
+ interpolation='bicubic',
102
+ scale=224,
103
+ type='RandomResizedCrop'),
104
+ dict(prob=0.5, type='RandomFlip'),
105
+ dict(type='PackInputs'),
106
+ ],
107
+ split='train',
108
+ type='ImageNet'),
109
+ num_workers=8,
110
+ persistent_workers=True,
111
+ pin_memory=True,
112
+ sampler=dict(shuffle=True, type='DefaultSampler'))
113
+ train_pipeline = [
114
+ dict(type='LoadImageFromFile'),
115
+ dict(
116
+ backend='pillow',
117
+ crop_ratio_range=(
118
+ 0.2,
119
+ 1.0,
120
+ ),
121
+ interpolation='bicubic',
122
+ scale=224,
123
+ type='RandomResizedCrop'),
124
+ dict(prob=0.5, type='RandomFlip'),
125
+ dict(type='PackInputs'),
126
+ ]
127
+ vis_backends = [
128
+ dict(type='LocalVisBackend'),
129
+ ]
130
+ visualizer = dict(
131
+ type='UniversalVisualizer', vis_backends=[
132
+ dict(type='LocalVisBackend'),
133
+ ])
134
+ work_dir = './work_dirs/mae_lama-base-p16_8xb512-amp-coslr-800e_in1k'