File size: 792 Bytes
3dbfcc9
446b3a0
3dbfcc9
 
 
 
 
 
 
 
 
 
 
446b3a0
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from transformers.models.codegen import CodeGenPreTrainedModel, CodeGenModel
from transformers import PreTrainedTokenizerBase
from .modeling_measurement_pred import MeasurementPredictorMixin
from .configuration_code_gen_measuremet_pred import CodeGenMeasurementPredictorConfig


class CodeGenMeasurementPredictor(CodeGenPreTrainedModel, MeasurementPredictorMixin):
    """CodeGen backbone combined with ``MeasurementPredictorMixin``.

    Holds a ``CodeGenModel`` as ``self.transformer`` and registers
    ``CodeGenMeasurementPredictorConfig`` as the model's ``config_class``.
    Any measurement-prediction heads / forward logic presumably come from
    ``MeasurementPredictorMixin`` (defined elsewhere — confirm there).
    """

    config_class = CodeGenMeasurementPredictorConfig

    def __init__(self, config):
        """Build the CodeGen backbone from *config*.

        Args:
            config: configuration object passed to both the pretrained-model
                base class and ``CodeGenModel``; expected to be a
                ``CodeGenMeasurementPredictorConfig``.
        """
        super().__init__(config)
        self.transformer = CodeGenModel(config)
        # Hugging Face hook that finalizes model construction (see
        # PreTrainedModel.post_init); must run after submodules are created.
        self.post_init()

    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase):
        """Configure *tokenizer* to use the ``' .'`` token for padding.

        Sets both ``tokenizer.pad_token`` and ``tokenizer.pad_token_id``.
        The model's own config is not modified.

        Args:
            tokenizer: the tokenizer to mutate in place.

        Raises:
            ValueError: if ``' .'`` does not encode to exactly one token id.
        """
        pad_token = ' .'
        # Encode without special tokens: a tokenizer that prepends BOS (or
        # similar) would otherwise make index 0 the special token rather
        # than the ' .' token.
        pad_token_ids = tokenizer.encode(pad_token, add_special_tokens=False)
        if len(pad_token_ids) != 1:
            raise ValueError(
                f"Expected pad token {pad_token!r} to encode to exactly one "
                f"token, got ids {pad_token_ids!r}"
            )
        pad_token_id = pad_token_ids[0]
        tokenizer.pad_token = pad_token
        # Set the id explicitly as well: the literal string ' .' may not be
        # a vocabulary entry itself (BPE vocabularies often store it in an
        # encoded form), so relying on pad_token alone could map to the
        # unknown token — NOTE(review): confirm for the tokenizers in use.
        tokenizer.pad_token_id = pad_token_id