File size: 1,326 Bytes
67a58db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import json
from transformers import PretrainedConfig
class GECToRConfig(PretrainedConfig):
    def __init__(
        self,
        model_id: str = 'bert-base-cased',
        p_dropout: float=0,
        label_pad_token: str='<PAD>',
        label_oov_token: str='<OOV>',
        d_pad_token: str='<PAD>',
        keep_label: str='$KEEP',
        correct_label: str='$CORRECT',
        incorrect_label: str='$INCORRECT',
        label_smoothing: float=0.0,
        has_add_pooling_layer: bool=True,
        initializer_range: float=0.02,
        **kwards
    ):
        super().__init__(**kwards)
        self.d_label2id = {
            "$CORRECT": 0,
            "$INCORRECT": 1,
            "<PAD>": 2
        }
        self.d_id2label = {v: k for k, v in self.d_label2id.items()}
        self.d_num_labels = len(self.d_label2id)
        self.model_id = model_id
        self.p_dropout = p_dropout
        self.label_pad_token = label_pad_token
        self.label_oov_token = label_oov_token
        self.d_pad_token = d_pad_token
        self.keep_label = keep_label
        self.correct_label = correct_label
        self.incorrect_label = incorrect_label
        self.label_smoothing = label_smoothing
        self.has_add_pooling_layer = has_add_pooling_layer
        self.initializer_range = initializer_range