kinensake's picture
Modify: requirements.txt
2ea9ced
raw
history blame
5.14 kB
#!/usr/bin/env python3
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
import argparse
import itertools
import os
import sys
import torch
from ..models.auto import AutoLMScorer as LMScorer
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed options. ``sentences_file_path`` may be ``"-"`` to read
        sentences from stdin; ``--cuda`` defaults to -1, which selects the CPU.
    """
    supported_models = ", ".join(LMScorer.supported_model_names())
    cli = argparse.ArgumentParser(
        description="Get sentences probability using a language model.",
    )
    add = cli.add_argument

    # Input source (a path, or "-" for stdin).
    add(
        "sentences_file_path",
        metavar="sentences-file-path",
        type=str,
        help=(
            "A file containing sentences to score, one per line."
            " If - is given as filename it reads from stdin instead."
        ),
    )

    # Model selection and scoring behaviour.
    add(
        "--model-name",
        "-m",
        type=str,
        default="gpt2",
        help="The pretrained language model to use. Can be one of: %s."
        % supported_models,
    )
    add(
        "--tokens",
        "-t",
        action="store_true",
        help="If provided it provides the probability of each token of each sentence.",
    )
    add(
        "--log-prob",
        "-lp",
        action="store_true",
        help="If provided log probabilities are returned instead.",
    )
    add(
        "--reduce",
        "-r",
        type=str,
        default="prod",
        help=(
            "Reduce strategy applied on token probabilities to get the sentence score."
            " Available strategies are: prod, mean, gmean, hmean."
        ),
    )

    # Execution and output formatting.
    add(
        "--batch-size",
        "-b",
        type=int,
        default=1,
        help="Number of sentences to process in parallel.",
    )
    add(
        "--significant-figures",
        "-sf",
        type=int,
        default=5,
        help="Number of significant figures to use when printing numbers.",
    )
    add(
        "--cuda",
        type=int,
        default=-1,
        help="If provided it runs the model on the given cuda device.",
    )
    add(
        "--debug",
        action="store_true",
        help="If provided it provides additional logging in case of errors.",
    )

    return cli.parse_args()
def normalize_args(args: argparse.Namespace) -> None:
    """Canonicalise path-like arguments in place.

    The sentences file path is replaced by ``os.path.realpath`` of itself
    (absolute, symlinks resolved). The stdin marker ``"-"`` is left as is.
    """
    path = args.sentences_file_path
    if path == "-":
        return
    args.sentences_file_path = os.path.realpath(path)
def validate_args(args: argparse.Namespace) -> None:
    """Check that the parsed arguments are usable before loading any model.

    Checks run in a fixed order and the first failure wins.

    Raises:
        ValueError: if the sentences file does not exist, a CUDA device was
            requested but none is available or the index is out of range, or
            the batch size / significant-figure count is not positive.
    """
    reads_file = args.sentences_file_path != "-"
    if reads_file and not os.path.isfile(args.sentences_file_path):
        raise ValueError("The provided sentences file path is invalid.")

    wants_gpu = args.cuda >= 0
    if wants_gpu and not torch.cuda.is_available():
        raise ValueError("No Cuda device found.")

    n_devices = torch.cuda.device_count()
    if args.cuda >= n_devices:
        raise ValueError("Invalid Cuda device: %d/%d." % (args.cuda, n_devices))

    if not args.batch_size > 0:
        raise ValueError("The batch size must be positive.")
    if not args.significant_figures > 0:
        raise ValueError("The number of significant figures must be positive.")
T1 = TypeVar("T1")  # pylint: disable=invalid-name


def grouper(iterable: Iterable[T1], size: int) -> Generator[List[T1], None, None]:
    """Lazily yield consecutive lists of up to ``size`` items from ``iterable``.

    For a positive ``size`` every list except possibly the last holds exactly
    ``size`` items; iteration stops once the source is exhausted. Each chunk
    is pulled from the source only when the consumer asks for it.
    """
    source = iter(iterable)
    chunk = list(itertools.islice(source, size))
    while chunk:
        yield chunk
        chunk = list(itertools.islice(source, size))
def main(args: argparse.Namespace) -> None:
    """Score every input sentence and print the results, tab-separated.

    For each input line (whitespace-stripped) one ``sentence<TAB>score`` line
    is printed. With ``--tokens`` it is followed by one ``token<TAB>score``
    line per token and a blank separator line. Numbers are printed with
    ``--significant-figures`` significant digits using ``%g`` formatting.

    Args:
        args: arguments as produced by ``parse_args`` and already passed
            through ``normalize_args`` / ``validate_args``.
    """
    # pylint: disable=too-many-locals
    from_stdin = args.sentences_file_path == "-"
    if from_stdin:
        sentences_stream = sys.stdin
    else:
        # Closed in the ``finally`` below; a plain ``with`` is not used because
        # sys.stdin must never be closed by this function.
        sentences_stream = open(  # pylint: disable=consider-using-with
            args.sentences_file_path, "r"
        )
    try:
        # A negative --cuda value selects the CPU.
        device = torch.device("cuda:%d" % args.cuda if args.cuda >= 0 else "cpu")
        scorer = LMScorer.from_pretrained(
            args.model_name, device=device, batch_size=args.batch_size
        )
        # Loop-invariant output format, e.g. "%s\t%.5g" for 5 significant figures.
        line_fmt = f"%s\t%.{args.significant_figures}g"
        # Lines are read from the stream in chunks of twice the batch size.
        buffer_size = args.batch_size * 2

        for raw_lines in grouper(sentences_stream, buffer_size):
            sentences = [line.strip() for line in raw_lines]
            sent_scores = scorer.sentence_score(
                sentences, log=args.log_prob, reduce=args.reduce
            )
            sent_info = None
            if args.tokens:
                sent_info = scorer.tokens_score(sentences, log=args.log_prob)

            for i, sentence in enumerate(sentences):
                print(line_fmt % (sentence, sent_scores[i]))
                if not args.tokens:
                    continue
                # Each per-sentence entry is unpacked as (scores, <unused>, tokens).
                scores, _, tokens = sent_info[i]
                for score, token in zip(scores, tokens):
                    print(line_fmt % (token, score))
                # Blank line separates one sentence's token block from the next.
                print("")
    finally:
        # Release the input file even when model loading or scoring fails.
        if not from_stdin:
            sentences_stream.close()
def run() -> None:
    """Command-line entry point: parse, validate, and score the sentences.

    Failures are reported on stderr, so they never mix with the scores
    written to stdout. The process then exits with a non-zero status that
    shell pipelines can detect. With ``--debug`` the original exception is
    re-raised with its traceback instead.
    """
    # Pre-bound so the error handler works even if parse_args() itself raises
    # (e.g. while listing the supported model names) before ``args`` exists.
    args: Optional[argparse.Namespace] = None
    try:
        args = parse_args()
        normalize_args(args)
        validate_args(args)
        main(args)
    except KeyboardInterrupt:
        print("\nAborted!", file=sys.stderr)
        # 128 + SIGINT (2): the conventional exit status for an interrupted command.
        sys.exit(130)
    except Exception as err:  # pylint: disable=broad-except
        if args is not None and args.debug:
            raise
        print("Error: %s" % err, file=sys.stderr)
        sys.exit(1)
# Script entry point: delegate to run().
# NOTE(review): this module uses a relative import (``from ..models.auto import
# AutoLMScorer``), so executing the file directly as a script will presumably
# fail; it appears to be meant for ``python -m <package>...`` or a
# console-script wrapper around run() — confirm.
if __name__ == "__main__":
    run()