Modify: requirements.txt
- lm_scorer/__init__.py +0 -0
- lm_scorer/bin/__init__.py +0 -0
- lm_scorer/bin/cli.py +172 -0
- lm_scorer/models/__init__.py +0 -0
- lm_scorer/models/abc/__init__.py +0 -0
- lm_scorer/models/abc/base.py +103 -0
- lm_scorer/models/abc/batch.py +35 -0
- lm_scorer/models/abc/transformers.py +16 -0
- lm_scorer/models/auto.py +34 -0
- lm_scorer/models/gpt2.py +85 -0
- requirements.txt +0 -1
lm_scorer/__init__.py
ADDED
File without changes

lm_scorer/bin/__init__.py
ADDED
File without changes
lm_scorer/bin/cli.py
ADDED
@@ -0,0 +1,172 @@
#!/usr/bin/env python3

from typing import *  # pylint: disable=wildcard-import,unused-wildcard-import

import argparse
import itertools
import os
import sys

import torch

from ..models.auto import AutoLMScorer as LMScorer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Get sentences probability using a language model.",
    )
    parser.add_argument(
        "sentences_file_path",
        metavar="sentences-file-path",
        type=str,
        help="A file containing sentences to score, one per line."
        " If - is given as filename it reads from stdin instead.",
    )
    parser.add_argument(
        "--model-name",
        "-m",
        type=str,
        default="gpt2",
        help="The pretrained language model to use. Can be one of: %s."
        % ", ".join(LMScorer.supported_model_names()),
    )
    parser.add_argument(
        "--tokens",
        "-t",
        action="store_true",
        help="If provided it provides the probability of each token of each sentence.",
    )
    parser.add_argument(
        "--log-prob",
        "-lp",
        action="store_true",
        help="If provided log probabilities are returned instead.",
    )
    parser.add_argument(
        "--reduce",
        "-r",
        type=str,
        default="prod",
        help="Reduce strategy applied on token probabilities to get the sentence score."
        " Available strategies are: prod, mean, gmean, hmean.",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        type=int,
        default=1,
        help="Number of sentences to process in parallel.",
    )
    parser.add_argument(
        "--significant-figures",
        "-sf",
        type=int,
        default=5,
        help="Number of significant figures to use when printing numbers.",
    )
    parser.add_argument(
        "--cuda",
        type=int,
        default=-1,
        help="If provided it runs the model on the given cuda device.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="If provided it provides additional logging in case of errors.",
    )
    return parser.parse_args()


def normalize_args(args: argparse.Namespace) -> None:
    if args.sentences_file_path != "-":
        args.sentences_file_path = os.path.realpath(args.sentences_file_path)


def validate_args(args: argparse.Namespace) -> None:
    if args.sentences_file_path != "-":
        if not os.path.isfile(args.sentences_file_path):
            raise ValueError("The provided sentences file path is invalid.")

    if args.cuda >= 0 and not torch.cuda.is_available():
        raise ValueError("No Cuda device found.")

    if args.cuda >= torch.cuda.device_count():
        device_count = torch.cuda.device_count()
        raise ValueError("Invalid Cuda device: %d/%d." % (args.cuda, device_count))

    if args.batch_size <= 0:
        raise ValueError("The batch size must be positive.")

    if args.significant_figures <= 0:
        raise ValueError("The number of significant figures must be positive.")


T1 = TypeVar("T1")  # pylint: disable=invalid-name


def grouper(iterable: Iterable[T1], size: int) -> Generator[List[T1], None, None]:
    it = iter(iterable)  # pylint: disable=invalid-name
    while True:
        chunk = list(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk


def main(args: argparse.Namespace) -> None:
    # pylint: disable=too-many-locals
    if args.sentences_file_path == "-":
        sentences_stream = sys.stdin
    else:
        sentences_stream = open(args.sentences_file_path, "r")

    sig_fig = args.significant_figures
    batch_size = args.batch_size
    device = torch.device("cuda:%d" % args.cuda if args.cuda >= 0 else "cpu")
    scorer = LMScorer.from_pretrained(
        args.model_name, device=device, batch_size=batch_size
    )

    buffer_size = args.batch_size * 2
    for sentences in grouper(sentences_stream, buffer_size):
        sentences = [sentence.strip() for sentence in sentences]

        sent_scores = scorer.sentence_score(
            sentences, log=args.log_prob, reduce=args.reduce
        )
        if args.tokens:
            sent_info = scorer.tokens_score(sentences, log=args.log_prob)

        sent_num = len(sentences)
        for i in range(sent_num):
            sentence, sent_score = sentences[i], sent_scores[i]
            print(f"%s\t%.{sig_fig}g" % (sentence, sent_score))
            if args.tokens:
                scores, _, tokens = sent_info[i]
                for score, token in zip(scores, tokens):
                    print(f"%s\t%.{sig_fig}g" % (token, score))
                print("")

    if args.sentences_file_path != "-":
        sentences_stream.close()


def run() -> None:
    try:
        args = parse_args()

        normalize_args(args)
        validate_args(args)
        main(args)
    except KeyboardInterrupt:
        print("\nAborted!")
    except Exception as err:  # pylint: disable=broad-except
        if args.debug:
            raise
        print("Error: %s" % err)


if __name__ == "__main__":
    run()
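
Taken together, main() reduces to a handful of scorer calls. A minimal programmatic sketch of the same pipeline (assuming the vendored package is importable; the sentences and formatting below are illustrative, not part of the diff):

import torch
from lm_scorer.models.auto import AutoLMScorer as LMScorer

device = torch.device("cpu")  # or e.g. torch.device("cuda:0"), as with --cuda 0
scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=1)

sentences = ["I am going home.", "This is an example."]
sent_scores = scorer.sentence_score(sentences, log=True, reduce="prod")
for sentence, score in zip(sentences, sent_scores):
    print("%s\t%.5g" % (sentence, score))
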
lm_scorer/models/__init__.py
ADDED
File without changes

lm_scorer/models/abc/__init__.py
ADDED
File without changes
lm_scorer/models/abc/base.py
ADDED
@@ -0,0 +1,103 @@
from typing import *  # pylint: disable=wildcard-import,unused-wildcard-import
from abc import ABC, abstractmethod

import math

import torch


class LMScorer(ABC):
    def __init__(self, model_name: str, **kwargs: Any) -> None:
        self._build(model_name, kwargs)

    @overload
    def sentence_score(
        self, text: str, log: bool = False, reduce: str = "prod"
    ) -> float:
        ...

    @overload
    def sentence_score(
        self, text: List[str], log: bool = False, reduce: str = "prod"
    ) -> List[float]:
        ...

    def sentence_score(
        self, text: Union[str, List[str]], log: bool = False, reduce: str = "prod",
    ) -> Union[float, List[float]]:
        sentences = [text] if isinstance(text, str) else text
        scores: List[float] = []
        if len(sentences) == 0:
            return scores

        outputs = self._tokens_log_prob(sentences)
        for output in outputs:
            log_probs = output[0]
            tlen = log_probs.shape[0]

            if reduce == "prod":
                score = log_probs.sum()
            elif reduce == "mean":
                score = log_probs.logsumexp(0) - math.log(tlen)
            elif reduce == "gmean":
                score = log_probs.mean(0)
            elif reduce == "hmean":
                score = log_probs.neg().logsumexp(0).neg() + math.log(tlen)
            else:
                raise ValueError("Unrecognized scoring strategy: %s" % reduce)
            if not log:
                score = score.exp()

            scores.append(score.item())

        return scores[0] if isinstance(text, str) else scores

    @overload
    def tokens_score(
        self, text: str, log: bool = False
    ) -> Tuple[List[float], List[int], List[str]]:
        ...

    @overload
    def tokens_score(
        self, text: List[str], log: bool = False
    ) -> List[Tuple[List[float], List[int], List[str]]]:
        ...

    def tokens_score(
        self, text: Union[str, List[str]], log: bool = False
    ) -> Union[
        Tuple[List[float], List[int], List[str]],
        List[Tuple[List[float], List[int], List[str]]],
    ]:
        sentences = [text] if isinstance(text, str) else text
        outputs: List[Tuple[List[float], List[int], List[str]]] = []
        if len(sentences) == 0:
            return outputs

        for log_probs, ids, tokens in self._tokens_log_prob(sentences):
            scores = log_probs if log else log_probs.exp()
            scores = cast(torch.DoubleTensor, scores)
            output = (scores.tolist(), ids.tolist(), tokens)
            outputs.append(output)

        return outputs[0] if isinstance(text, str) else outputs

    @classmethod
    def supported_model_names(cls) -> Iterable[str]:
        return cls._supported_model_names()

    def _build(self, model_name: str, options: Dict[str, Any]) -> None:
        # pylint: disable=attribute-defined-outside-init, unused-argument
        self.model_name = model_name

    @abstractmethod
    def _tokens_log_prob(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
        ...  # pragma: no cover

    @classmethod
    @abstractmethod
    def _supported_model_names(cls) -> Iterable[str]:
        ...  # pragma: no cover
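
All four reduce strategies operate on token log-probabilities and stay in log space via logsumexp identities: prod sums the logs, mean is the log of the arithmetic mean of the probabilities, gmean the log of their geometric mean, and hmean the log of their harmonic mean. A quick sanity sketch against the direct definitions (the token probabilities below are made up for illustration):

import math
import torch

probs = torch.tensor([0.5, 0.25, 0.125]).double()  # hypothetical token probs
log_probs = probs.log()
tlen = log_probs.shape[0]

prod = log_probs.sum().exp()
mean = (log_probs.logsumexp(0) - math.log(tlen)).exp()
gmean = log_probs.mean(0).exp()
hmean = (log_probs.neg().logsumexp(0).neg() + math.log(tlen)).exp()

assert torch.isclose(prod, probs.prod())                   # product of probs
assert torch.isclose(mean, probs.mean())                   # arithmetic mean
assert torch.isclose(gmean, probs.prod() ** (1 / tlen))    # geometric mean
assert torch.isclose(hmean, tlen / (1 / probs).sum())      # harmonic mean
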
lm_scorer/models/abc/batch.py
ADDED
@@ -0,0 +1,35 @@
# pylint: disable=abstract-method
from typing import *  # pylint: disable=wildcard-import,unused-wildcard-import
from abc import abstractmethod

import torch

from .base import LMScorer


class BatchedLMScorer(LMScorer):
    # @overrides
    def _build(self, model_name: str, options: Dict[str, Any]) -> None:
        super()._build(model_name, options)

        batch_size = options.get("batch_size", 1)
        if batch_size < 1:
            raise ValueError("The batch_size option must be positive")
        # pylint: disable=attribute-defined-outside-init
        self.batch_size = batch_size

    # @overrides
    def _tokens_log_prob(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
        outputs = []
        for i in range(0, len(text), self.batch_size):
            batch = text[i : i + self.batch_size]
            outputs.extend(self._tokens_log_prob_for_batch(batch))
        return outputs

    @abstractmethod
    def _tokens_log_prob_for_batch(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
        ...  # pragma: no cover
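
The chunking contract is easiest to see with a toy subclass (not part of the diff; the stub scores are meaningless and only exercise the batching loop):

import torch
from lm_scorer.models.abc.batch import BatchedLMScorer

class EchoScorer(BatchedLMScorer):
    # Stub implementation: one zero log-prob per sentence.
    def _tokens_log_prob_for_batch(self, text):
        return [(torch.zeros(1).double(), torch.zeros(1).long(), [t]) for t in text]

    @classmethod
    def _supported_model_names(cls):
        return ["echo"]

scorer = EchoScorer("echo", batch_size=2)
# Five sentences are sliced into batches of sizes 2, 2 and 1.
assert len(scorer._tokens_log_prob(["a", "b", "c", "d", "e"])) == 5
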
lm_scorer/models/abc/transformers.py
ADDED
@@ -0,0 +1,16 @@
# pylint: disable=abstract-method
from typing import *  # pylint: disable=wildcard-import,unused-wildcard-import

import os

from .batch import BatchedLMScorer


class TransformersLMScorer(BatchedLMScorer):
    # @overrides
    def _build(self, model_name: str, options: Dict[str, Any]) -> None:
        super()._build(model_name, options)

        # Make transformers cache path configurable.
        cache_dir = os.environ.get("TRANSFORMERS_CACHE_DIR", ".transformers_cache")
        options["cache_dir"] = options.get("cache_dir", cache_dir)
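
The only logic here is a default for the cache location: an explicit cache_dir option wins over the TRANSFORMERS_CACHE_DIR environment variable, which wins over the hard-coded fallback. A small illustration of that precedence (paths are hypothetical):

import os

os.environ["TRANSFORMERS_CACHE_DIR"] = "/tmp/hf-cache"  # hypothetical path
options = {"cache_dir": "/data/cache"}                  # hypothetical explicit option

cache_dir = os.environ.get("TRANSFORMERS_CACHE_DIR", ".transformers_cache")
options["cache_dir"] = options.get("cache_dir", cache_dir)
assert options["cache_dir"] == "/data/cache"  # explicit option takes precedence
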
lm_scorer/models/auto.py
ADDED
@@ -0,0 +1,34 @@
from typing import *  # pylint: disable=wildcard-import,unused-wildcard-import

import itertools

from .abc.base import LMScorer
from .gpt2 import GPT2LMScorer


class AutoLMScorer:
    MODEL_CLASSES = [GPT2LMScorer]

    def __init__(self):
        raise EnvironmentError(
            "AutoLMscorer is designed to be instantiated "
            "using the `AutoLMscorer.from_pretrained(model_name)`"
            "method"
        )

    @classmethod
    def from_pretrained(cls, model_name: str, **kwargs: Any) -> LMScorer:
        for model_class in cls.MODEL_CLASSES:
            if model_name not in model_class.supported_model_names():
                continue
            return model_class(model_name, **kwargs)
        raise ValueError(
            "Unrecognized model name."
            "Can be one of: %s" % ", ".join(cls.supported_model_names()),
        )

    @classmethod
    def supported_model_names(cls) -> Iterable[str]:
        classes = cls.MODEL_CLASSES
        models = map(lambda c: c.supported_model_names(), classes)
        return itertools.chain.from_iterable(models)
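
AutoLMScorer is the intended entry point: it dispatches on the model name and refuses direct instantiation. A minimal usage sketch (the model weights are downloaded on first use; "gpt2" is one of the names reported by supported_model_names()):

from lm_scorer.models.auto import AutoLMScorer

scorer = AutoLMScorer.from_pretrained("gpt2", batch_size=2)
print(scorer.sentence_score("I like this package.", reduce="mean"))
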
lm_scorer/models/gpt2.py
ADDED
@@ -0,0 +1,85 @@
from typing import *  # pylint: disable=wildcard-import,unused-wildcard-import


import torch
from transformers import AutoTokenizer, GPT2LMHeadModel
from transformers import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_utils import BatchEncoding

from .abc.transformers import TransformersLMScorer


class GPT2LMScorer(TransformersLMScorer):
    # @overrides
    def _build(self, model_name: str, options: Dict[str, Any]) -> None:
        super()._build(model_name, options)

        # pylint: disable=attribute-defined-outside-init
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_fast=True, add_special_tokens=False
        )
        # Add the pad token to GPT2 dictionary.
        # len(tokenizer) = vocab_size + 1
        self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|pad|>"]})
        self.tokenizer.pad_token = "<|pad|>"

        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        # We need to resize the embedding layer because we added the pad token.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.eval()
        if "device" in options:
            self.model.to(options["device"])

    def _add_special_tokens(self, text: str) -> str:
        return self.tokenizer.bos_token + text + self.tokenizer.eos_token

    # @overrides
    def _tokens_log_prob_for_batch(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
        outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = []
        if len(text) == 0:
            return outputs

        # TODO: Handle overflowing elements for long sentences
        text = list(map(self._add_special_tokens, text))
        encoding: BatchEncoding = self.tokenizer.batch_encode_plus(
            text, return_tensors="pt",
        )
        with torch.no_grad():
            ids = encoding["input_ids"].to(self.model.device)
            attention_mask = encoding["attention_mask"].to(self.model.device)
            nopad_mask = ids != self.tokenizer.pad_token_id
            logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0]

        for sent_index in range(len(text)):
            sent_nopad_mask = nopad_mask[sent_index]
            # len(tokens) = len(text[sent_index]) + 1
            sent_tokens = [
                tok
                for i, tok in enumerate(encoding.tokens(sent_index))
                if sent_nopad_mask[i] and i != 0
            ]

            # sent_ids.shape = [len(text[sent_index]) + 1]
            sent_ids = ids[sent_index, sent_nopad_mask][1:]
            # logits.shape = [len(text[sent_index]) + 1, vocab_size]
            sent_logits = logits[sent_index, sent_nopad_mask][:-1, :]
            sent_logits[:, self.tokenizer.pad_token_id] = float("-inf")
            # ids_scores.shape = [seq_len + 1]
            sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1)
            # log_prob.shape = [seq_len + 1]
            sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1)

            sent_log_probs = cast(torch.DoubleTensor, sent_log_probs)
            sent_ids = cast(torch.LongTensor, sent_ids)

            output = (sent_log_probs, sent_ids, sent_tokens)
            outputs.append(output)

        return outputs

    # @overrides
    @classmethod
    def _supported_model_names(cls) -> Iterable[str]:
        return GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()
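
The per-token score computed above, gather(logits, ids) - logsumexp(logits), is exactly log_softmax evaluated at the observed ids, without materializing the full [seq_len, vocab_size] log-softmax tensor. A quick equivalence check with made-up numbers:

import torch

logits = torch.randn(4, 10).double()  # [seq_len, vocab_size], random stand-in
ids = torch.randint(10, (4,))         # observed token ids

via_gather = logits.gather(1, ids.unsqueeze(1)).squeeze(1) - logits.logsumexp(1)
via_softmax = logits.log_softmax(1).gather(1, ids.unsqueeze(1)).squeeze(1)
assert torch.allclose(via_gather, via_softmax)
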
requirements.txt
CHANGED
@@ -6,5 +6,4 @@ python-Levenshtein==0.12.2
 fuzzywuzzy==0.18.0
 tokenizers==0.10.2
 fsspec==2021.5.0
-lm-scorer==0.4.2 --install-option='--ignore-requires-python'
 errant