from transformers import PretrainedConfig
from transformers.utils import logging
from transformers.models.esm import EsmConfig
from transformers.models.bert import BertConfig

logger = logging.get_logger(__name__)


class ProtSTConfig(PretrainedConfig):
    r"""
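    This is the configuration class to store the configuration of a [`ProtSTModel`]. It combines a protein encoder
    configuration ([`EsmConfig`]) and a text encoder configuration ([`BertConfig`]).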

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
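        protein_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize the [`EsmConfig`] of the protein encoder
            ([`EsmForProteinRepresentation`]). If `None`, the default [`EsmConfig`] values are used.
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize the [`BertConfig`] of the text encoder
            ([`BertForPubMed`]). If `None`, the default [`BertConfig`] values are used.
        kwargs (*optional*):
            Dictionary of keyword arguments passed to [`PretrainedConfig`].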
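
    Example (a minimal sketch; `configuration_protst` is an assumed module name for this file, so adjust the import
    to wherever [`ProtSTConfig`] is defined):

    ```python
    from transformers import BertConfig, EsmConfig

    from configuration_protst import ProtSTConfig  # hypothetical import path for this file

    # No arguments: both sub-configurations are built with their default values.
    default_config = ProtSTConfig()

    # Plain dictionaries override individual fields of each sub-configuration.
    dict_config = ProtSTConfig(protein_config={"num_hidden_layers": 6}, text_config={"num_hidden_layers": 4})

    # Existing sub-configuration objects can be combined with the class method.
    obj_config = ProtSTConfig.from_protein_text_configs(
        EsmConfig(num_hidden_layers=6), BertConfig(num_hidden_layers=4)
    )
    ```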
    """

    model_type = "protst"

    def __init__(
        self,
        protein_config=None,
        text_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if protein_config is None:
            protein_config = {}
            logger.info("`protein_config` is `None`. Initializing the `EsmConfig` with default values.")

        if text_config is None:
            text_config = {}
            logger.info("`text_config` is `None`. Initializing the `BertConfig` with default values.")

        self.protein_config = EsmConfig(**protein_config)
        self.text_config = BertConfig(**text_config)

    @classmethod
    def from_protein_text_configs(
        cls, protein_config: EsmConfig, text_config: BertConfig, **kwargs
    ):
        r"""
        Instantiate a [`ProtSTConfig`] (or a derived class) from a protein model configuration ([`EsmConfig`]) and a
        text model configuration ([`BertConfig`]).

        Returns:
            [`ProtSTConfig`]: An instance of a configuration object.
        """

        return cls(protein_config=protein_config.to_dict(), text_config=text_config.to_dict(), **kwargs)