from typing import List, Optional

from transformers import PretrainedConfig

class GCNConfig(PretrainedConfig):
    """Configuration container for the custom ``"gcn"`` model type.

    Every constructor argument is stored unchanged as an attribute of the
    same name, and any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): how each value is consumed is decided by model code that
    is not visible here; the parameter descriptions below are inferred from
    the names and should be confirmed against that code.

    Args:
        input_feature: Presumably the size of the input (node) features.
        emb_input: Presumably the size of the input embedding -- TODO confirm.
        hidden_size: Presumably the hidden width of the GCN layers.
        n_layers: Presumably the number of GCN layers.
        num_classes: Presumably the number of output targets.
        smiles: Optional list of SMILES strings kept on the config, or
            ``None`` when none are supplied.
        processor_class: Name of the processor class associated with this
            config.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "gcn"

    def __init__(
        self,
        input_feature: int = 64,
        emb_input: int = 20,
        hidden_size: int = 64,
        n_layers: int = 6,
        num_classes: int = 1,
        smiles: Optional[List[str]] = None,
        processor_class: str = "SmilesProcessor",
        **kwargs,
    ):
        # Architecture hyper-parameters.
        self.input_feature = input_feature
        self.emb_input = emb_input
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.num_classes = num_classes

        # Data / preprocessing metadata carried alongside the architecture.
        self.smiles = smiles
        self.processor_class = processor_class

        # The base initializer runs last, after the custom attributes are set
        # (same order as the original code).
        super().__init__(**kwargs)
if __name__ == "__main__":
    # Example usage: build a config from explicit settings and save it
    # under the "custom-gcn" path.
    example_settings = {
        "input_feature": 64,
        "emb_input": 20,
        "hidden_size": 64,
        "n_layers": 6,
        "num_classes": 1,
        "smiles": ["C", "CC", "CCC"],
        "processor_class": "SmilesProcessor",
    }
    gcn_config = GCNConfig(**example_settings)
    gcn_config.save_pretrained("custom-gcn")