|
from transformers.models.codegen import CodeGenPreTrainedModel, CodeGenModel |
|
from transformers import PreTrainedTokenizerBase |
|
from .modeling_measurement_pred import MeasurementPredictorMixin |
|
from .configuration_code_gen_measuremet_pred import CodeGenMeasurementPredictorConfig |
|
|
|
|
|
class CodeGenMeasurementPredictor(CodeGenPreTrainedModel, MeasurementPredictorMixin):
    """CodeGen backbone combined with ``MeasurementPredictorMixin``.

    This class builds the wrapped ``CodeGenModel`` and provides a helper to
    configure a tokenizer's padding token. Everything else (e.g. prediction
    heads and the forward pass) is presumably supplied by
    ``MeasurementPredictorMixin`` -- confirm in ``modeling_measurement_pred``.
    """

    config_class = CodeGenMeasurementPredictorConfig

    def __init__(self, config):
        """Construct the model from ``config``.

        Args:
            config: Model configuration (``config_class`` declares it as a
                ``CodeGenMeasurementPredictorConfig``). It is passed both to
                the base class and to the wrapped ``CodeGenModel``.
        """
        super().__init__(config)
        # NOTE(review): presumably named ``transformer`` to match CodeGen's
        # ``base_model_prefix`` so pretrained CodeGen checkpoints load into
        # this submodule -- do not rename without confirming.
        self.transformer = CodeGenModel(config)
        # Hugging Face post-construction hook (e.g. weight initialization);
        # must run after all submodules have been created.
        self.post_init()

    def set_pad_token(self, tokenizer: PreTrainedTokenizerBase, pad_token: str = " ."):
        """Configure ``tokenizer`` in place to pad with ``pad_token``.

        Args:
            tokenizer: Tokenizer whose ``pad_token`` / ``pad_token_id`` are
                overwritten.
            pad_token: Text to use for padding. Its encoding, with special
                tokens excluded, must be exactly one token. Defaults to ``" ."``.

        Raises:
            ValueError: If ``pad_token`` does not encode to exactly one token.
        """
        # Exclude special tokens (e.g. a BOS) so the id we pick belongs to
        # ``pad_token`` itself rather than to a prepended marker.
        token_ids = tokenizer.encode(pad_token, add_special_tokens=False)
        if len(token_ids) != 1:
            # Taking only the first id of a multi-token encoding would make
            # the model pad with a partial token, so fail loudly instead.
            raise ValueError(
                f"pad_token {pad_token!r} must encode to exactly one token, "
                f"got token ids {token_ids}"
            )
        pad_token_id = token_ids[0]

        tokenizer.pad_token = pad_token
        # Set the id explicitly as well. NOTE(review): depending on the
        # transformers version, this assignment may also re-derive
        # ``pad_token`` from the vocabulary entry for ``pad_token_id``.
        tokenizer.pad_token_id = pad_token_id
|
|