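"""Score sentences with a pretrained language model.

Reads sentences from a file (or stdin when "-" is given), one per line, and
prints each sentence followed by its probability; with --tokens, per-token
probabilities are printed as well.

Illustrative invocation (the actual entry-point name depends on how the
package is installed):

    python -m lm_scorer.bin.lm_scorer sentences.txt --model-name gpt2 --tokens
"""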
from typing import Generator, Iterable, List, TypeVar

import argparse
import itertools
import os
import sys

import torch

from ..models.auto import AutoLMScorer as LMScorer
|
|
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Get sentence probabilities using a language model.",
    )
    parser.add_argument(
        "sentences_file_path",
        metavar="sentences-file-path",
        type=str,
        help="A file containing the sentences to score, one per line."
        " If - is given as filename, sentences are read from stdin instead.",
    )
    parser.add_argument(
        "--model-name",
        "-m",
        type=str,
        default="gpt2",
        help="The pretrained language model to use. Can be one of: %s."
        % ", ".join(LMScorer.supported_model_names()),
    )
    parser.add_argument(
        "--tokens",
        "-t",
        action="store_true",
        help="If provided, the probability of each token of each sentence is printed as well.",
    )
    parser.add_argument(
        "--log-prob",
        "-lp",
        action="store_true",
        help="If provided, log probabilities are returned instead.",
    )
    parser.add_argument(
        "--reduce",
        "-r",
        type=str,
        default="prod",
        help="Reduce strategy applied to token probabilities to get the sentence score."
        " Available strategies are: prod, mean, gmean, hmean.",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        type=int,
        default=1,
        help="Number of sentences to process in parallel.",
    )
    parser.add_argument(
        "--significant-figures",
        "-sf",
        type=int,
        default=5,
        help="Number of significant figures to use when printing numbers.",
    )
    parser.add_argument(
        "--cuda",
        type=int,
        default=-1,
        help="If provided, run the model on the given CUDA device.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="If provided, print the full traceback in case of errors.",
    )
    return parser.parse_args()
|
|
def normalize_args(args: argparse.Namespace) -> None:
    if args.sentences_file_path != "-":
        args.sentences_file_path = os.path.realpath(args.sentences_file_path)
|
|
def validate_args(args: argparse.Namespace) -> None:
    if args.sentences_file_path != "-":
        if not os.path.isfile(args.sentences_file_path):
            raise ValueError("The provided sentences file path is invalid.")

    if args.cuda >= 0:
        if not torch.cuda.is_available():
            raise ValueError("No CUDA device found.")

        device_count = torch.cuda.device_count()
        if args.cuda >= device_count:
            raise ValueError("Invalid CUDA device: %d/%d." % (args.cuda, device_count))

    if args.batch_size <= 0:
        raise ValueError("The batch size must be positive.")

    if args.significant_figures <= 0:
        raise ValueError("The number of significant figures must be positive.")
|
|
T1 = TypeVar("T1") |
|
|
def grouper(iterable: Iterable[T1], size: int) -> Generator[List[T1], None, None]:
    # Yield successive chunks of at most `size` items from the iterable.
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, size))
        if not chunk:
            return
        yield chunk
|
|
def main(args: argparse.Namespace) -> None:
    if args.sentences_file_path == "-":
        sentences_stream = sys.stdin
    else:
        sentences_stream = open(args.sentences_file_path, "r")

    sig_fig = args.significant_figures
    batch_size = args.batch_size
    device = torch.device("cuda:%d" % args.cuda if args.cuda >= 0 else "cpu")
    scorer = LMScorer.from_pretrained(
        args.model_name, device=device, batch_size=batch_size
    )

    # Read sentences in chunks of two batches so the whole file never has to
    # be loaded into memory at once.
    buffer_size = args.batch_size * 2
    for sentences in grouper(sentences_stream, buffer_size):
        sentences = [sentence.strip() for sentence in sentences]

        sent_scores = scorer.sentence_score(
            sentences, log=args.log_prob, reduce=args.reduce
        )
        if args.tokens:
            sent_info = scorer.tokens_score(sentences, log=args.log_prob)

        # Print one line per sentence, optionally followed by its token scores.
        sent_num = len(sentences)
        for i in range(sent_num):
            sentence, sent_score = sentences[i], sent_scores[i]
            print(f"%s\t%.{sig_fig}g" % (sentence, sent_score))
            if args.tokens:
                scores, _, tokens = sent_info[i]
                for score, token in zip(scores, tokens):
                    print(f"%s\t%.{sig_fig}g" % (token, score))
                print("")

    if args.sentences_file_path != "-":
        sentences_stream.close()
|
|
def run() -> None:
    args = None
    try:
        args = parse_args()

        normalize_args(args)
        validate_args(args)
        main(args)
    except KeyboardInterrupt:
        print("\nAborted!")
    except Exception as err:
        # With --debug, re-raise to show the full traceback.
        if args is not None and args.debug:
            raise
        print("Error: %s" % err)
|
|
if __name__ == "__main__":
    run()