from typing import List, Optional

from transformers import PretrainedConfig


class ChAdaViTConfig(PretrainedConfig):
    """Hugging Face configuration for the ChAdaViT model (``model_type="chadavit"``).

    Every constructor argument is stored unchanged as an attribute of the same
    name. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the parameter meanings below are inferred from the names only;
    confirm them against the ChAdaViT model implementation.

    Args:
        img_size: Image size(s). ``None`` (the default) is replaced by a fresh
            ``[224]`` list for each instance, so instances never share a
            default list object.
        in_chans: Presumably input channels fed to the patch embedding (default 1).
        embed_dim: Presumably the token embedding width (default 192).
        patch_size: Presumably the patch side length in pixels (default 16).
        num_classes: Presumably the classifier output size; 0 likely means no
            classification head (default 0).
        depth: Presumably the number of transformer blocks (default 12).
        num_heads: Presumably the attention heads per block (default 12).
        drop_rate: Presumably a dropout probability (default 0.0).
        drop_path_rate: Presumably a stochastic-depth rate (default 0.0).
        return_all_tokens: Presumably whether the model returns every token
            rather than a pooled one (default True).
        max_number_channels: Presumably the maximum number of image channels
            the model supports (default 10).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "chadavit"

    def __init__(
        self,
        img_size: Optional[List[int]] = None,
        in_chans: int = 1,
        embed_dim: int = 192,
        patch_size: int = 16,
        num_classes: int = 0,
        depth: int = 12,
        num_heads: int = 12,
        drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        return_all_tokens: bool = True,
        max_number_channels: int = 10,
        **kwargs,
    ):
        # A literal ``[224]`` default would be evaluated once and shared by every
        # instance; build a new list per call instead.
        self.img_size = [224] if img_size is None else img_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.depth = depth
        self.num_heads = num_heads
        self.drop_rate = drop_rate
        self.drop_path_rate = drop_path_rate
        self.return_all_tokens = return_all_tokens
        self.max_number_channels = max_number_channels
        super().__init__(**kwargs)