# AlexNet model configuration (a Hugging Face transformers PretrainedConfig subclass).
from transformers import PretrainedConfig | |
class AlexNetConfig(PretrainedConfig):
    """Configuration for an AlexNet model.

    Args:
        id2label: Optional mapping from class index to label name. Forwarded to
            ``PretrainedConfig.__init__`` so it is stored on the config and kept
            across ``save_pretrained`` / ``from_pretrained``.
        label2id: Optional inverse mapping from label name to class index.
        labels: Optional sequence of label names. When ``id2label`` is not
            given, it is built from ``labels`` (index ``i`` -> ``labels[i]``),
            so ``num_labels`` becomes ``len(labels)``.
        input_channels: Number of channels in the input images (default 3).
        **kwargs: Remaining ``PretrainedConfig`` options. ``output_hidden_states``
            and ``return_dict`` default to ``True`` for this config, but an
            explicit value in ``kwargs`` takes precedence.
    """

    model_type = "alexnet"

    def __init__(
        self,
        id2label=None,
        label2id=None,
        labels=None,
        input_channels=3,
        **kwargs,
    ):
        # ``None`` instead of ``[]`` avoids a mutable default shared across calls.
        labels = list(labels) if labels is not None else []

        # Derive the label mappings from ``labels`` only when the caller did not
        # provide them explicitly.
        if id2label is None and labels:
            id2label = dict(enumerate(labels))
            if label2id is None:
                label2id = {name: idx for idx, name in id2label.items()}

        # PretrainedConfig.__init__ assigns these from kwargs with its own
        # defaults. Assigning them on ``self`` before calling it would be
        # overwritten, so route them through kwargs instead.
        kwargs.setdefault("output_hidden_states", True)
        kwargs.setdefault("return_dict", True)

        super().__init__(id2label=id2label, label2id=label2id, **kwargs)

        # Custom attribute, set after the base initializer so it is not clobbered.
        self.input_channels = input_channels