File size: 452 Bytes
bc1bebe
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# Define a PretrainedConfig subclass for the AlexNet model.
from transformers import PretrainedConfig

class AlexNetConfig(PretrainedConfig):
    """Hugging Face configuration for an AlexNet image classifier.

    Args:
        id2label: Optional mapping from class index to label name. Takes
            precedence over ``labels`` when both are given.
        label2id: Optional mapping from label name to class index. Derived
            from ``labels`` when omitted and ``labels`` is non-empty.
        labels: Optional sequence of label names; position ``i`` is class
            ``i``. Used only to build ``id2label``/``label2id`` when those
            are not passed explicitly.
        input_channels: Number of channels of the input images.
        **kwargs: Forwarded to ``PretrainedConfig.__init__`` (for example
            ``num_labels``, ``output_hidden_states``, ``return_dict``).
            ``output_hidden_states`` and ``return_dict`` default to True here.
    """

    model_type = "alexnet"

    def __init__(
        self,
        id2label=None,
        label2id=None,
        labels=None,
        input_channels=3,
        **kwargs,
    ):
        # `None` instead of a mutable `[]` default, which would be shared
        # across every call.
        labels = list(labels) if labels is not None else []

        # Build the label mappings from `labels` unless given explicitly.
        if labels:
            if id2label is None:
                id2label = dict(enumerate(labels))
            if label2id is None:
                label2id = {label: idx for idx, label in enumerate(labels)}

        self.input_channels = input_channels

        # PretrainedConfig.__init__ (re)assigns id2label, label2id, num_labels,
        # output_hidden_states and return_dict from kwargs, so they must be
        # passed through it; assigning them on `self` beforehand would be
        # silently overwritten. setdefault keeps values loaded from a saved
        # config.json (or passed by the caller) intact.
        kwargs.setdefault("output_hidden_states", True)
        kwargs.setdefault("return_dict", True)
        super().__init__(id2label=id2label, label2id=label2id, **kwargs)