File size: 3,273 Bytes
2ea9ced |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
from abc import ABC, abstractmethod
import math
import torch
class LMScorer(ABC):
    """Abstract base class for scoring text with a language model.

    Concrete scorers implement ``_tokens_log_prob`` (for each sentence: token
    log-probabilities, token ids and token strings) and
    ``_supported_model_names``. This class turns those per-token values into
    sentence-level scores (``sentence_score``) and plain-Python token-level
    scores (``tokens_score``).
    """

    # Strategies accepted by the ``reduce`` argument of ``sentence_score``.
    _REDUCE_STRATEGIES: Tuple[str, ...] = ("prod", "mean", "gmean", "hmean")

    def __init__(self, model_name: str, **kwargs: Any) -> None:
        """Initialise the scorer by delegating to ``_build``.

        Args:
            model_name: Identifier of the model to score with.
            **kwargs: Extra options, forwarded to ``_build`` as one dict.
        """
        self._build(model_name, kwargs)

    @overload
    def sentence_score(
        self, text: str, log: bool = False, reduce: str = "prod"
    ) -> float:
        ...

    @overload
    def sentence_score(
        self, text: List[str], log: bool = False, reduce: str = "prod"
    ) -> List[float]:
        ...

    def sentence_score(
        self, text: Union[str, List[str]], log: bool = False, reduce: str = "prod",
    ) -> Union[float, List[float]]:
        """Score whole sentences by reducing their token probabilities.

        Args:
            text: A single sentence, or a list of sentences.
            log: If True, return the score in log space; otherwise return
                ``exp`` of the log-score.
            reduce: How to combine the token probabilities p_1..p_n:
                ``"prod"``  -> product of p_i,
                ``"mean"``  -> arithmetic mean of p_i,
                ``"gmean"`` -> geometric mean of p_i,
                ``"hmean"`` -> harmonic mean of p_i.

        Returns:
            A float for a single sentence, or a list of floats (one per
            sentence, in input order) for a list. An empty list yields ``[]``.

        Raises:
            ValueError: If ``reduce`` is not one of the strategies above.
        """
        # Validate before running the model, so that a typo in ``reduce`` does
        # not cost a full forward pass and is rejected even for empty input.
        if reduce not in self._REDUCE_STRATEGIES:
            raise ValueError(f"Unrecognized scoring strategy: {reduce}")

        sentences = [text] if isinstance(text, str) else text
        if not sentences:
            return []

        scores: List[float] = []
        for log_probs, _ids, _tokens in self._tokens_log_prob(sentences):
            score = self._reduce_log_probs(log_probs, reduce)
            if not log:
                score = score.exp()
            scores.append(score.item())
        return scores[0] if isinstance(text, str) else scores

    @staticmethod
    def _reduce_log_probs(log_probs: torch.Tensor, reduce: str) -> torch.Tensor:
        """Combine token log-probabilities into one sentence log-score.

        ``log_probs`` is reduced along dim 0. It is presumably a 1-D tensor of
        per-token log-probabilities (callers ``.item()`` the result); confirm
        against the concrete ``_tokens_log_prob`` implementation. ``"mean"``
        and ``"hmean"`` call ``math.log`` on the token count, so they require
        at least one token.
        """
        tlen = log_probs.shape[0]
        if reduce == "prod":
            # log(prod p_i) = sum(log p_i)
            return log_probs.sum()
        if reduce == "mean":
            # log((1/n) * sum p_i) = logsumexp(log p_i) - log n  (stable)
            return log_probs.logsumexp(0) - math.log(tlen)
        if reduce == "gmean":
            # log((prod p_i) ** (1/n)) = mean(log p_i)
            return log_probs.mean(0)
        if reduce == "hmean":
            # log(n / sum(1 / p_i)) = log n - logsumexp(-log p_i)
            return log_probs.neg().logsumexp(0).neg() + math.log(tlen)
        raise ValueError(f"Unrecognized scoring strategy: {reduce}")

    @overload
    def tokens_score(
        self, text: str, log: bool = False
    ) -> Tuple[List[float], List[int], List[str]]:
        ...

    @overload
    def tokens_score(
        self, text: List[str], log: bool = False
    ) -> List[Tuple[List[float], List[int], List[str]]]:
        ...

    def tokens_score(
        self, text: Union[str, List[str]], log: bool = False
    ) -> Union[
        Tuple[List[float], List[int], List[str]],
        List[Tuple[List[float], List[int], List[str]]],
    ]:
        """Return per-token scores, ids and token strings.

        Args:
            text: A single sentence, or a list of sentences.
            log: If True, scores are log-probabilities; otherwise they are
                probabilities (``exp`` of the log-probabilities).

        Returns:
            For a single sentence, a tuple ``(scores, ids, tokens)`` of plain
            Python lists. For a list of sentences, a list of such tuples in
            input order. An empty list yields ``[]``.
        """
        sentences = [text] if isinstance(text, str) else text
        if not sentences:
            return []

        outputs: List[Tuple[List[float], List[int], List[str]]] = []
        for log_probs, ids, tokens in self._tokens_log_prob(sentences):
            scores = log_probs if log else log_probs.exp()
            # Typing-only narrowing; has no runtime effect.
            scores = cast(torch.DoubleTensor, scores)
            outputs.append((scores.tolist(), ids.tolist(), tokens))
        return outputs[0] if isinstance(text, str) else outputs

    @classmethod
    def supported_model_names(cls) -> Iterable[str]:
        """Return the model names this scorer class supports."""
        return cls._supported_model_names()

    def _build(self, model_name: str, options: Dict[str, Any]) -> None:
        """Hook for subclass setup; the base version only stores the name.

        Args:
            model_name: Identifier of the model, stored as ``self.model_name``.
            options: Keyword options from ``__init__``; unused here.
        """
        # pylint: disable=attribute-defined-outside-init, unused-argument
        self.model_name = model_name

    @abstractmethod
    def _tokens_log_prob(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
        """Return, per sentence, (token log-probs, token ids, token strings)."""
        ...  # pragma: no cover

    @classmethod
    @abstractmethod
    def _supported_model_names(cls) -> Iterable[str]:
        """Return the model names supported by the concrete subclass."""
        ...  # pragma: no cover
|