File size: 1,010 Bytes
a0a61c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from typing import List, Optional

from transformers import PretrainedConfig
class ChAdaViTConfig(PretrainedConfig):
    """Configuration container for the ChAdaViT model.

    Every constructor argument below is stored verbatim as an instance
    attribute of the same name. Any remaining keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    Args:
        img_size: List of image sizes. Defaults to ``[224]``; a new list is
            built for each instance so configs never share (and cannot
            accidentally mutate) a common default. Passing ``None`` also
            selects the default. (Presumably input side length(s) in
            pixels -- TODO confirm against the model code.)
        in_chans: Defaults to 1. (Presumably channels per input image.)
        embed_dim: Defaults to 192. (Presumably the token embedding width.)
        patch_size: Defaults to 16. (Presumably the patch side length.)
        num_classes: Defaults to 0. (Presumably the classifier output size;
            0 may mean "no head" -- TODO confirm.)
        depth: Defaults to 12. (Presumably the number of transformer blocks.)
        num_heads: Defaults to 12. (Presumably attention heads per block.)
        drop_rate: Defaults to 0.0. (Presumably a dropout probability.)
        drop_path_rate: Defaults to 0.0. (Presumably a stochastic-depth rate.)
        return_all_tokens: Defaults to True.
        max_number_channels: Defaults to 10.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    # Identifier transformers uses to associate serialized configs with this class.
    model_type = "chadavit"

    def __init__(
        self,
        img_size: Optional[List[int]] = None,
        in_chans: int = 1,
        embed_dim: int = 192,
        patch_size: int = 16,
        num_classes: int = 0,
        depth: int = 12,
        num_heads: int = 12,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        return_all_tokens: bool = True,
        max_number_channels: int = 10,
        **kwargs,
    ):
        # A list literal used directly as a default is evaluated once and
        # shared by every call; build a fresh list per instance instead.
        self.img_size = [224] if img_size is None else img_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.depth = depth
        self.num_heads = num_heads
        self.drop_rate = drop_rate
        self.drop_path_rate = drop_path_rate
        self.return_all_tokens = return_all_tokens
        self.max_number_channels = max_number_channels
        # Model-specific attributes are set first; the base initializer then
        # handles the generic options carried in **kwargs.
        super().__init__(**kwargs)
|