from typing import List, Dict
import torch
from torch import nn

try:
    from transformers import AutoTokenizer, AutoModelForTokenClassification
except ImportError:
    from ditk import logging
    logging.warning("not found transformer, please install it using: pip install transformers")
from ding.utils import MODEL_REGISTRY


@MODEL_REGISTRY.register('language_transformer')
class LanguageTransformer(nn.Module):
    """
    Overview:
        The LanguageTransformer network. It downloads a pre-trained language model and adds a head on top of it.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            model_name: str = "bert-base-uncased",
            add_linear: bool = False,
            embedding_size: int = 128,
            freeze_encoder: bool = True
    ) -> None:
        """
        Overview:
            Init the LanguageTransformer Model according to input arguments.
        Arguments:
            - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
            - add_linear (:obj:`bool`): Whether to add a linear layer on top of the language model, defaults to \
            ``False``.
            - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
            - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model during training, \
            defaults to ``True``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)

        # Freeze transformer encoder and only train the linear layer
        if freeze_encoder:
            for param in self.model.parameters():
                param.requires_grad = False

        if add_linear:
            # Add a small, adjustable linear layer on top of language model tuned through RL
            self.embedding_size = embedding_size
            self.linear = nn.Linear(
                self.model.config.hidden_size, embedding_size
            )  # 768 for bert-base-uncased, distilbert-base-uncased
        else:
            self.linear = None

    def _calc_embedding(self, x: List[str]) -> torch.Tensor:
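        """
        Overview:
            Tokenize a list of strings and return the [CLS] sentence embedding of each one, optionally projected \
            by the added linear layer.
        """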
        # ``truncation=True`` truncates any prompt that exceeds the ``max_length`` of the tokenizer. ``padding=True``
        # pads the shorter prompts in the batch up to the length of the longest one. Together, these settings give
        # every encoded sequence in the batch the same length, which enables batch-wise computation.
        input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
        output = self.model(**input, output_hidden_states=True)
        # Get last layer hidden states
        last_hidden_states = output.hidden_states[-1]
        # Get [CLS] hidden states
        sentence_embedding = last_hidden_states[:, 0, :]  # len(input_list) x hidden_size

        if self.linear:
            sentence_embedding = self.linear(sentence_embedding)  # len(input_list) x embedding_size

        return sentence_embedding

    def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
        """
        Overview:
            The forward computation graph of LanguageTransformer: take two lists of strings as input and predict \
            their matching scores.
        Arguments:
            - train_samples (:obj:`List[str]`): One list of strings.
            - candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
        Returns:
            - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
            corresponding ``torch.distributions.Categorical`` object.

        Examples:
            >>> test_pids = [1]
            >>> cand_pids = [0, 2, 4]
            >>> problems = [ \
                "This is problem 0", "This is the first question", "Second problem is here", "Another problem", \
                "This is the last problem" \
            ]
            >>> ctxt_list = [problems[pid] for pid in test_pids]
            >>> cands_list = [problems[pid] for pid in cand_pids]
            >>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
            >>> scores = model(ctxt_list, cands_list)
            >>> assert scores.shape == (1, 3)
        """
        prompt_embedding = self._calc_embedding(train_samples)
        cands_embedding = self._calc_embedding(candidate_samples)
        scores = torch.mm(prompt_embedding, cands_embedding.t())
        return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
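

# A minimal usage sketch, not part of the module's API. It assumes ``transformers`` is installed and that the
# "bert-base-uncased" weights can be downloaded from the HuggingFace hub.
if __name__ == "__main__":
    ctxt_list = ["This is the first question"]
    cands_list = ["Second problem is here", "Another problem", "This is the last problem"]
    model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
    output = model(ctxt_list, cands_list)
    # ``logit`` holds the raw matching scores of shape (len(ctxt_list), len(cands_list));
    # ``dist`` wraps the same scores in a ``torch.distributions.Categorical`` object.
    assert output['logit'].shape == (1, 3)
    print(output['logit'])
    print(output['dist'].probs)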