Yak-hbdx's picture
uploaded TransfoRNA repo
0b11a42 verified
raw
history blame
2.66 kB
import os
from dataclasses import dataclass, field
from pickletools import int4
from typing import List
@dataclass
class GeneEmbeddModelConfig:
# input dim for the embedding and positional encoders
# as well as all k,q,v input/output dims
model_input: str = "seq-struct"
num_embed_hidden: int = 256
ff_hidden_dim: List = field(default_factory=lambda: [1200, 800])
feed_forward1_hidden: int = 1024
num_attention_project: int = 64
num_encoder_layers: int = 2
dropout: float = 0.3
n: int = 121
window:int = 4
relative_attns: List = field(default_factory=lambda: [int(360), int(360)])
num_attention_heads: int = 4
tokens_len: int = 0 #will be infered later
second_input_token_len:int = 0 # is infered in runtime
vocab_size: int = 0 # is infered in runtime
second_input_vocab_size: int = 0 # is infered in runtime
tokenizer: str = (
"overlap" # either overlap or no_overlap or overlap_multi_window
)
# how many extra window sizes other than deafault window
num_classes: int = 0 #will be infered in runtime
class_weights :List = field(default_factory=lambda: [])
tokens_mapping_dict: dict = None
#false input percentage
false_input_perc:float = 0.2
model_input: str = "seq-struct"
@dataclass
class GeneEmbeddTrainConfig:
dataset_path_train: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/train.h5ad"
dataset_path_test: str = "/data/hbdx_ldap_local/analysis/data/sncRNA/test.h5ad"
labels_mapping_path:str = "/data/hbdx_ldap_local/analysis/data/sncRNA/labels_mapping_dict.pkl"
device: str = "cuda"
l2_weight_decay: float = 1e-5
batch_size: int = 64
batch_per_epoch:int = 0 #will be infered later
label_smoothing_sim:float = 0.0
label_smoothing_clf:float = 0.0
# learning rate
learning_rate: float = 1e-3 # final learning rate ie 'lr annealed to'
lr_warmup_start: float = 0.1 # start of lr before initial linear warmup section
lr_warmup_end: float = 1 # end of linear warmup section , annealing begin
# TODO: 122 is the number of train batches per epoch, should be infered and set
# warmup batch should be in the form epoch*(train batch per epoch)
warmup_epoch: int = 10 # how many batches linear warm up for
final_epoch: int = 20 # final batch of training when want learning rate
top_k: int = int(
0.05 * batch_size
) # if the corresponding rna/GE appears in the top k, the correctly classified
label_smoothing: float = 0.0
cross_val: bool = False
filter_seq_length:bool = True
train_epoch: int = 800
max_epochs:int = 1000