|
from typing import Optional

from transformers import PretrainedConfig
|
|
|
class ZettHypernetConfig(PretrainedConfig):
    """Configuration object holding the hyperparameters of a Zett hypernetwork.

    Every named constructor argument is stored unchanged on the instance under
    an attribute of the same name. Any additional keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.

    This class performs no validation and derives no values; how each field is
    interpreted is decided by the code that consumes the config.
    """

    # Declared on the class (not only on the instance) because the
    # ``PretrainedConfig`` machinery looks up ``model_type`` on the class when
    # serializing and when registering with ``AutoConfig``.
    model_type = "zett_hypernetwork"

    def __init__(
        self,
        hn_model_name_or_path: str = "roberta-base",
        hn_surface_maxlen: int = 16,
        hn_n_layers: int = 3,
        n_embd: int = 768,
        hn_hidden_size: Optional[int] = None,
        hn_intermediate_size: Optional[int] = None,
        hn_rescale_embeddings: bool = False,
        use_unigram_bias: bool = False,
        hn_embed_target_priors: bool = False,
        hn_add_inter_token_attention: bool = False,
        hn_inter_token_attention_bias_by_priors: bool = False,
        hn_inter_token_attention_bias_scaler: float = 1.0,
        hn_n_inter_token_blocks: int = 16,
        hn_language_adapter_bottleneck_dim: int = 0,
        hn_embed_using_source_embeddings: bool = False,
        hn_concat_last_hidden_state: bool = False,
        hn_single_head: bool = False,
        hn_predict_bias: bool = True,
        hn_num_attention_heads: Optional[int] = None,
        hn_embed_lang_id: bool = False,
        hn_model_type: str = "roberta",
        n_langs: Optional[int] = None,
        **kwargs
    ):
        # Let the base class consume the generic config options first.
        super().__init__(**kwargs)

        # Kept for backward compatibility: earlier versions set the identifier
        # on the instance, so it continues to appear in ``self.__dict__``.
        self.model_type = "zett_hypernetwork"

        # Backbone selection and overall sizes.
        self.hn_model_name_or_path = hn_model_name_or_path
        self.hn_surface_maxlen = hn_surface_maxlen
        self.hn_n_layers = hn_n_layers
        self.n_embd = n_embd
        self.hn_hidden_size = hn_hidden_size
        self.hn_intermediate_size = hn_intermediate_size

        # Embedding / prior related switches.
        self.hn_rescale_embeddings = hn_rescale_embeddings
        self.use_unigram_bias = use_unigram_bias
        self.hn_embed_target_priors = hn_embed_target_priors

        # Inter-token attention options.
        self.hn_add_inter_token_attention = hn_add_inter_token_attention
        self.hn_inter_token_attention_bias_by_priors = (
            hn_inter_token_attention_bias_by_priors
        )
        self.hn_inter_token_attention_bias_scaler = (
            hn_inter_token_attention_bias_scaler
        )
        self.hn_n_inter_token_blocks = hn_n_inter_token_blocks

        # Remaining architecture switches.
        self.hn_language_adapter_bottleneck_dim = hn_language_adapter_bottleneck_dim
        self.hn_embed_using_source_embeddings = hn_embed_using_source_embeddings
        self.hn_concat_last_hidden_state = hn_concat_last_hidden_state
        self.hn_single_head = hn_single_head
        self.hn_predict_bias = hn_predict_bias
        self.hn_num_attention_heads = hn_num_attention_heads
        self.hn_embed_lang_id = hn_embed_lang_id
        self.hn_model_type = hn_model_type
        self.n_langs = n_langs
|
|