|
import os
from dataclasses import dataclass, field
from pickletools import int4
from typing import List, Optional
|
|
|
|
|
@dataclass
class GeneEmbeddModelConfig:
    """Model hyper-parameters for the GeneEmbedd model.

    Fields whose defaults are 0 / None / [] (token lengths, vocabulary sizes,
    ``num_classes``, ``class_weights``, ``tokens_mapping_dict``) look like
    placeholders that are filled in at runtime from the data -- TODO confirm
    against the code that builds this config.
    """

    # Input representation consumed by the model. "seq-struct" presumably
    # means sequence + structure -- confirm with the consuming model code.
    model_input: str = "seq-struct"
    num_embed_hidden: int = 256
    # Hidden sizes of a feed-forward stack, presumably one entry per layer
    # -- verify against the model code.
    ff_hidden_dim: List = field(default_factory=lambda: [1200, 800])
    feed_forward1_hidden: int = 1024
    num_attention_project: int = 64
    num_encoder_layers: int = 2
    dropout: float = 0.3
    # NOTE(review): the meaning of `n` is not evident from this file.
    n: int = 121
    window: int = 4
    # Two entries by default, matching num_encoder_layers == 2; presumably
    # one value per encoder layer -- confirm the intended length.
    relative_attns: List = field(default_factory=lambda: [360, 360])
    num_attention_heads: int = 4

    # Placeholders, presumably set from the tokenized dataset at runtime.
    tokens_len: int = 0
    second_input_token_len: int = 0
    vocab_size: int = 0
    second_input_vocab_size: int = 0
    # Name of the tokenization scheme; accepted values are defined by the
    # tokenizer code, not here.
    tokenizer: str = "overlap"

    # Placeholder, presumably inferred from the labels at runtime.
    num_classes: int = 0
    # A fresh empty list per instance (default_factory avoids a shared list).
    class_weights: List = field(default_factory=list)
    # None until a mapping is supplied; Optional because None is the default.
    tokens_mapping_dict: Optional[dict] = None

    # Presumably the fraction of "false" inputs (0.2 == 20%) -- confirm.
    false_input_perc: float = 0.2
|
|
|
|
|
@dataclass
class GeneEmbeddTrainConfig:
    """Training settings for the GeneEmbedd model.

    Covers data locations, optimizer/batching options, learning-rate
    warm-up values and epoch limits.
    """

    # Data locations.
    # NOTE(review): these are machine-specific absolute paths; override them
    # per environment. The labels mapping is a pickle file -- only load it
    # from a trusted location.
    dataset_path_train: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/train.h5ad"
    dataset_path_test: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/test.h5ad"
    labels_mapping_path: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/labels_mapping_dict.pkl"

    # Optimizer and batching.
    device: str = "cuda"
    l2_weight_decay: float = 1e-5
    batch_size: int = 64
    # Placeholder, presumably computed from the dataset size at runtime.
    batch_per_epoch: int = 0
    label_smoothing_sim: float = 0.0
    label_smoothing_clf: float = 0.0

    # Learning rate and warm-up settings (interpreted by scheduler code that
    # is not part of this file).
    learning_rate: float = 1e-3
    lr_warmup_start: float = 0.1
    lr_warmup_end: float = 1
    warmup_epoch: int = 10
    final_epoch: int = 20

    # Evaluated ONCE while the class body executes, using the class-level
    # batch_size default above: int(0.05 * 64) == 3. Passing a different
    # batch_size to the constructor does NOT recompute top_k.
    top_k: int = int(0.05 * batch_size)
    label_smoothing: float = 0.0
    cross_val: bool = False
    filter_seq_length: bool = True

    # Epoch limits.
    train_epoch: int = 800
    max_epochs: int = 1000