from transformers import PretrainedConfig

import torch


class BiLSTMConfig(PretrainedConfig):
    """Configuration for a BiLSTM model wrapped in the Hugging Face config API."""

    def __init__(self, vocab_size=23626, embed_dim=100,
                 num_layers=1, hidden_dim=256, dropout=0.33,
                 output_dim=128, predict_output=10, device="cuda:0", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size          # size of the token vocabulary
        self.embed_dim = embed_dim            # token embedding dimension
        self.num_layers = num_layers          # number of stacked LSTM layers
        self.hidden_dim = hidden_dim          # LSTM hidden state size
        self.dropout = dropout                # dropout rate between layers
        self.output_dim = output_dim          # projection size after the LSTM
        self.predict_output = predict_output  # number of output classes
        self.device = device                  # target device, e.g. "cuda:0"
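
# --- Usage sketch (illustrative addition, not part of the original file) ---
# Because BiLSTMConfig subclasses PretrainedConfig, it inherits JSON
# (de)serialization for free and can round-trip through the standard
# save_pretrained / from_pretrained calls. The directory name below is a
# hypothetical example. Registering the config with AutoConfig would
# additionally require a `model_type` class attribute on BiLSTMConfig.
if __name__ == "__main__":
    config = BiLSTMConfig(hidden_dim=512, dropout=0.25)
    config.save_pretrained("bilstm-checkpoint")  # writes config.json
    reloaded = BiLSTMConfig.from_pretrained("bilstm-checkpoint")
    print(reloaded.hidden_dim)  # 512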