from typing import Optional

from transformers import PretrainedConfig

class ZettHypernetConfig(PretrainedConfig):
    """Configuration for the ZeTT hypernetwork.

    All arguments are stored as attributes so they are serialized with the
    usual ``PretrainedConfig`` machinery; parameters prefixed with ``hn_``
    configure the hypernetwork itself.
    """

    def __init__(
        self,
        hn_model_name_or_path: str = "roberta-base",
        hn_surface_maxlen: int = 16,
        hn_n_layers: int = 3,
        n_embd: int = 768,
        hn_hidden_size: Optional[int] = None,
        hn_intermediate_size: Optional[int] = None,
        hn_rescale_embeddings: bool = False,
        use_unigram_bias: bool = False,
        hn_embed_target_priors: bool = False,
        hn_add_inter_token_attention: bool = False,
        hn_inter_token_attention_bias_by_priors: bool = False,
        hn_inter_token_attention_bias_scaler: float = 1.0,
        hn_n_inter_token_blocks: int = 16,
        hn_language_adapter_bottleneck_dim: int = 0,
        hn_embed_using_source_embeddings: bool = False,
        hn_concat_last_hidden_state: bool = False,
        hn_single_head: bool = False,
        hn_predict_bias: bool = True,
        hn_num_attention_heads: Optional[int] = None,
        hn_embed_lang_id: bool = False,
        hn_model_type: str = "roberta",
        n_langs: Optional[int] = None,  # set in train.py
        **kwargs
    ):
        super().__init__(**kwargs)

        self.model_type = "zett_hypernetwork"
        self.hn_model_name_or_path = hn_model_name_or_path
        self.hn_surface_maxlen = hn_surface_maxlen
        self.hn_n_layers = hn_n_layers
        self.n_embd = n_embd
        self.hn_hidden_size = hn_hidden_size
        self.hn_intermediate_size = hn_intermediate_size
        self.hn_rescale_embeddings = hn_rescale_embeddings
        self.use_unigram_bias = use_unigram_bias
        self.hn_embed_target_priors = hn_embed_target_priors
        self.hn_add_inter_token_attention = hn_add_inter_token_attention
        self.hn_inter_token_attention_bias_by_priors = (
            hn_inter_token_attention_bias_by_priors
        )
        self.hn_inter_token_attention_bias_scaler = hn_inter_token_attention_bias_scaler
        self.hn_n_inter_token_blocks = hn_n_inter_token_blocks
        self.hn_language_adapter_bottleneck_dim = hn_language_adapter_bottleneck_dim
        self.hn_embed_using_source_embeddings = hn_embed_using_source_embeddings
        self.hn_concat_last_hidden_state = hn_concat_last_hidden_state
        self.hn_single_head = hn_single_head
        self.hn_predict_bias = hn_predict_bias
        self.hn_num_attention_heads = hn_num_attention_heads
        self.hn_embed_lang_id = hn_embed_lang_id
        self.hn_model_type = hn_model_type
        self.n_langs = n_langs
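

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream file): build a config
# with a few overrides and round-trip it through the standard PretrainedConfig
# save/load machinery. The output directory name below is an assumption.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ZettHypernetConfig(
        hn_model_name_or_path="roberta-base",
        hn_n_layers=3,
        n_embd=768,
    )

    # PretrainedConfig provides JSON (de)serialization out of the box.
    config.save_pretrained("zett_hypernet_config")  # writes config.json
    reloaded = ZettHypernetConfig.from_pretrained("zett_hypernet_config")

    assert reloaded.hn_model_name_or_path == "roberta-base"
    print(reloaded.to_json_string())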