from transformers import PretrainedConfig
import torch


class BiLSTMConfig(PretrainedConfig):
    """Configuration for a bidirectional LSTM model, built on Hugging Face's PretrainedConfig."""

    def __init__(
        self,
        vocab_size=23626,      # size of the token vocabulary
        embed_dim=100,         # dimension of the token embeddings
        num_layers=1,          # number of stacked LSTM layers
        hidden_dim=256,        # hidden size of each LSTM layer
        dropout=0.33,          # dropout probability
        output_dim=128,        # dimension of the LSTM output projection
        predict_output=10,     # number of final prediction outputs
        device="cuda:0",       # device the model is expected to run on
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.output_dim = output_dim
        self.predict_output = predict_output
        self.device = device
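

# Usage sketch (illustrative, not part of the original class): subclassing
# PretrainedConfig gives JSON round-tripping for free via save_pretrained /
# from_pretrained. The directory name and override values below are arbitrary
# examples, assuming only the BiLSTMConfig class defined above.
if __name__ == "__main__":
    config = BiLSTMConfig(hidden_dim=512, dropout=0.2)
    config.save_pretrained("./bilstm_config")         # writes config.json to this folder
    reloaded = BiLSTMConfig.from_pretrained("./bilstm_config")
    print(reloaded.hidden_dim, reloaded.dropout)      # 512 0.2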