# Source snapshot: 452 bytes, revision bc1bebe.
# Create a PretrainedConfig subclass describing the AlexNet model.
from transformers import PretrainedConfig
class AlexNetConfig(PretrainedConfig):
    """Configuration for an AlexNet model, built on ``PretrainedConfig``.

    The label mapping can be given explicitly via ``id2label`` / ``label2id``,
    or implicitly as an ordered sequence of class names in ``labels``
    (the name at index ``i`` gets class id ``i``).

    Args:
        id2label: Mapping from integer class id to class name. Takes
            precedence over ``labels`` when both are given.
        label2id: Mapping from class name to integer class id. Derived from
            ``id2label`` when omitted.
        labels: Optional sequence of class names, used to build ``id2label``
            when no explicit mapping is passed. When neither is given, the
            base class falls back to its own ``num_labels`` default.
        input_channels: Number of channels in the input images (default 3).
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
            ``output_hidden_states`` and ``return_dict`` default to ``True``
            unless overridden here.
    """

    model_type = "alexnet"

    def __init__(
        self,
        id2label=None,
        label2id=None,
        labels=None,
        *,
        input_channels=3,
        **kwargs,
    ):
        # Derive the id -> name mapping from `labels` only if none was given.
        if id2label is None and labels:
            id2label = dict(enumerate(labels))
        # Keep the inverse mapping consistent with id2label. Keys may arrive
        # as strings when a config is reloaded from JSON, hence int().
        if label2id is None and id2label is not None:
            label2id = {name: int(idx) for idx, name in id2label.items()}

        # PretrainedConfig.__init__ (re)initializes id2label, label2id,
        # num_labels, output_hidden_states and return_dict from its own
        # keyword arguments. Assigning them on `self` beforehand gets them
        # overwritten, so they are passed through instead.
        kwargs.setdefault("output_hidden_states", True)
        kwargs.setdefault("return_dict", True)
        super().__init__(id2label=id2label, label2id=label2id, **kwargs)

        # Model-specific hyperparameter, stored so it is serialized with the
        # config.
        self.input_channels = input_channels