Try to subclass PretrainedModel
Browse files- modeling_bert.py +2 -29
modeling_bert.py
CHANGED
@@ -22,7 +22,7 @@ import torch
|
|
22 |
import torch.nn as nn
|
23 |
import torch.nn.functional as F
|
24 |
from einops import rearrange
|
25 |
-
from transformers import
|
26 |
from .configuration_bert import JinaBertConfig
|
27 |
from transformers.models.bert.modeling_bert import (
|
28 |
BaseModelOutputWithPoolingAndCrossAttentions,
|
@@ -295,7 +295,7 @@ class BertPreTrainingHeads(nn.Module):
|
|
295 |
return prediction_scores, seq_relationship_score
|
296 |
|
297 |
|
298 |
-
class BertPreTrainedModel(
|
299 |
"""An abstract class to handle weights initialization and
|
300 |
a simple interface for dowloading and loading pretrained models.
|
301 |
"""
|
@@ -310,33 +310,6 @@ class BertPreTrainedModel(nn.Module):
|
|
310 |
)
|
311 |
self.config = config
|
312 |
|
313 |
-
@classmethod
|
314 |
-
def from_pretrained(cls, model_name, config, *inputs, **kwargs):
|
315 |
-
"""
|
316 |
-
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
|
317 |
-
Download and cache the pre-trained model file if needed.
|
318 |
-
|
319 |
-
Params:
|
320 |
-
pretrained_model_name_or_path: either:
|
321 |
-
- a path or url to a pretrained model archive containing:
|
322 |
-
. `bert_config.json` a configuration file for the model
|
323 |
-
. `pytorch_model.bin` a PyTorch dump of a BertForPretraining instance
|
324 |
-
- a path or url to a pretrained model archive containing:
|
325 |
-
. `bert_config.json` a configuration file for the model
|
326 |
-
. `model.chkpt` a TensorFlow checkpoint
|
327 |
-
*inputs, **kwargs: additional input for the specific Bert class
|
328 |
-
(ex: num_labels for BertForSequenceClassification)
|
329 |
-
"""
|
330 |
-
# Instantiate model.
|
331 |
-
model = cls(config, *inputs, **kwargs)
|
332 |
-
load_return = model.load_state_dict(state_dict_from_pretrained(model_name), strict=True)
|
333 |
-
logger.info(load_return)
|
334 |
-
return model
|
335 |
-
|
336 |
-
@classmethod
|
337 |
-
def _from_config(cls, config, **kwargs):
|
338 |
-
return cls(config, **kwargs)
|
339 |
-
|
340 |
|
341 |
class BertModel(BertPreTrainedModel):
|
342 |
def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
|
|
|
22 |
import torch.nn as nn
|
23 |
import torch.nn.functional as F
|
24 |
from einops import rearrange
|
25 |
+
from transformers import PretrainedModel
|
26 |
from .configuration_bert import JinaBertConfig
|
27 |
from transformers.models.bert.modeling_bert import (
|
28 |
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
|
295 |
return prediction_scores, seq_relationship_score
|
296 |
|
297 |
|
298 |
+
class BertPreTrainedModel(PretrainedModel):
|
299 |
"""An abstract class to handle weights initialization and
|
300 |
a simple interface for dowloading and loading pretrained models.
|
301 |
"""
|
|
|
310 |
)
|
311 |
self.config = config
|
312 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
313 |
|
314 |
class BertModel(BertPreTrainedModel):
|
315 |
def __init__(self, config: JinaBertConfig, add_pooling_layer=True):
|