|
from typing import List, Optional

from transformers import PretrainedConfig
|
|
|
|
|
class ChAdaViTConfig(PretrainedConfig):
    """Configuration object for the ``"chadavit"`` model type.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from their names
    only; this class just stores the values. Confirm against the model code
    that consumes this configuration.

    Args:
        img_size: Presumably the input image side length(s). When omitted
            (``None``), a fresh ``[224]`` list is created for this instance.
        in_chans: Presumably the number of input channels. Defaults to 1.
        embed_dim: Presumably the token embedding width. Defaults to 192.
        patch_size: Presumably the patch side length. Defaults to 16.
        num_classes: Presumably the classifier output size. Defaults to 0.
        depth: Presumably the number of transformer blocks. Defaults to 12.
        num_heads: Presumably the number of attention heads. Defaults to 12.
        drop_rate: Presumably a dropout probability. Defaults to 0.0.
        drop_path_rate: Presumably a stochastic-depth probability.
            Defaults to 0.0.
        return_all_tokens: Presumably selects whether all tokens are
            returned by the model. Defaults to True.
        max_number_channels: Presumably an upper bound on the number of
            channels the model handles. Defaults to 10.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier string for this configuration class.
    model_type = "chadavit"

    def __init__(
        self,
        img_size: Optional[List[int]] = None,
        in_chans: int = 1,
        embed_dim: int = 192,
        patch_size: int = 16,
        num_classes: int = 0,
        depth: int = 12,
        num_heads: int = 12,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        return_all_tokens: bool = True,
        max_number_channels: int = 10,
        **kwargs,
    ):
        # A literal ``[224]`` default would be one list object shared by every
        # instance; build a new list per instance so in-place edits to
        # ``config.img_size`` cannot leak into later default-constructed
        # configurations.
        if img_size is None:
            img_size = [224]

        self.img_size = img_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.depth = depth
        self.num_heads = num_heads
        self.drop_rate = drop_rate
        self.drop_path_rate = drop_path_rate
        self.return_all_tokens = return_all_tokens
        self.max_number_channels = max_number_channels

        # Base-class initialisation runs last, after the model-specific
        # attributes are set, and receives only the extra keyword arguments.
        super().__init__(**kwargs)
|
|