keithhon commited on
Commit
917ff2b
1 Parent(s): 50d616c

Upload dalle/utils/config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dalle/utils/config.py +123 -0
dalle/utils/config.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ from typing import Optional, List
8
+ from dataclasses import dataclass, field
9
+ from omegaconf import OmegaConf
10
+
11
+
12
+ @dataclass
13
+ class DataConfig:
14
+ dataset: Optional[str] = None
15
+ tokenizer_type: str = 'CharBPE'
16
+ context_length: int = 64
17
+ image_resolution: int = 256
18
+ transforms: str = 'dalle-vqvae'
19
+ bpe_pdrop: Optional[float] = None
20
+
21
+
22
+ @dataclass
23
+ class Stage1Hparams:
24
+ double_z: bool = False
25
+ z_channels: int = 256
26
+ resolution: int = 256
27
+ in_channels: int = 3
28
+ out_ch: int = 3
29
+ ch: int = 128
30
+ ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
31
+ num_res_blocks: int = 2
32
+ attn_resolutions: List[int] = field(default_factory=lambda: [16])
33
+ pdrop: float = 0.0
34
+
35
+
36
+ @dataclass
37
+ class Stage2Hparams:
38
+ embed_dim: int = 1536
39
+ n_layers: int = 42
40
+ n_heads: int = 24
41
+ n_dense_layers: int = 42
42
+ ctx_len_img: int = 256
43
+ ctx_len_txt: int = 64
44
+ embd_pdrop: float = 0.0
45
+ resid_pdrop: float = 0.0
46
+ attn_pdrop: float = 0.0
47
+ mlp_bias: bool = True
48
+ attn_bias: bool = True
49
+ gelu_use_approx: bool = False
50
+ use_head_txt: bool = True
51
+ n_classes: Optional[int] = None
52
+
53
+
54
+ @dataclass
55
+ class Stage1Config:
56
+ type: str = 'vqgan'
57
+ embed_dim: int = 256
58
+ n_embed: int = 16384
59
+ hparams: Stage1Hparams = Stage1Hparams()
60
+
61
+
62
+ @dataclass
63
+ class Stage2Config:
64
+ type: str = 'transformer1d'
65
+ vocab_size_txt: int = 16384
66
+ vocab_size_img: int = 16384
67
+ use_cls_cond: Optional[bool] = None
68
+ hparams: Stage2Hparams = Stage2Hparams()
69
+
70
+
71
+ @dataclass
72
+ class WarmupConfig:
73
+ epoch: int = 1
74
+ multiplier: int = 1
75
+ buffer_epoch: int = 0
76
+ min_lr: float = 0.0
77
+ mode: str = 'fix'
78
+ peak_lr: float = 1e-4
79
+ start_from_zero: bool = True
80
+
81
+
82
+ @dataclass
83
+ class OptConfig:
84
+ opt_type: str = 'adamW'
85
+ base_lr: float = 1e-4
86
+ weight_decay: float = 1e-4
87
+ betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
88
+ grad_clip_norm: float = 1.0
89
+
90
+ sched_type: str = 'cosine'
91
+ max_steps: int = 0
92
+ min_lr: float = 0.0
93
+
94
+
95
+ @dataclass
96
+ class ExpConfig:
97
+ local_batch_size: int = 4
98
+ total_batch_size: int = 512
99
+ valid_batch_size: int = 32
100
+ epochs: int = 10
101
+ save_ckpt_freq: int = 2
102
+ test_freq: int = 1
103
+ use_amp: bool = True
104
+
105
+
106
+ @dataclass
107
+ class DefaultConfig:
108
+ dataset: DataConfig = DataConfig()
109
+ stage1: Stage1Config = Stage1Config()
110
+ stage2: Stage2Config = Stage2Config()
111
+
112
+
113
+ @dataclass
114
+ class FineTuningConfig:
115
+ dataset: DataConfig = DataConfig()
116
+ stage1: Stage1Config = Stage1Config()
117
+ stage2: Stage2Config = Stage2Config()
118
+ optimizer: OptConfig = OptConfig()
119
+ experiment: ExpConfig = ExpConfig()
120
+
121
+
122
+ def get_base_config(use_default=True):
123
+ return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)