kinensake's picture
Modify: requirements.txt
2ea9ced
raw
history blame
3.27 kB
from typing import * # pylint: disable=wildcard-import,unused-wildcard-import
from abc import ABC, abstractmethod
import math
import torch
class LMScorer(ABC):
    """Abstract base class for scoring text with a language model.

    Subclasses supply the model-specific pieces by implementing
    ``_tokens_log_prob`` (per-token log-probabilities for a batch of
    sentences) and ``_supported_model_names``; they may also override
    ``_build`` to set themselves up. This class turns those per-token
    log-probabilities into sentence-level and token-level scores.
    """

    # Reduction strategies accepted by ``sentence_score``. Kept as a tuple so
    # that a membership test never fails on an unhashable ``reduce`` value.
    _REDUCE_STRATEGIES = ("prod", "mean", "gmean", "hmean")

    def __init__(self, model_name: str, **kwargs: Any) -> None:
        """Create a scorer, forwarding the keyword options to ``_build``.

        Args:
            model_name: Name of the model to use.
            **kwargs: Extra options, passed to ``_build`` as a single dict.
        """
        self._build(model_name, kwargs)

    @overload
    def sentence_score(
        self, text: str, log: bool = False, reduce: str = "prod"
    ) -> float:
        ...

    @overload
    def sentence_score(
        self, text: List[str], log: bool = False, reduce: str = "prod"
    ) -> List[float]:
        ...

    def sentence_score(
        self, text: Union[str, List[str]], log: bool = False, reduce: str = "prod",
    ) -> Union[float, List[float]]:
        """Score one sentence or a list of sentences.

        The per-token log-probabilities returned by ``_tokens_log_prob`` are
        reduced to a single value per sentence.

        Args:
            text: A single sentence, or a list of sentences.
            log: When true the score is returned in log space; otherwise it
                is exponentiated before being returned.
            reduce: How the token probabilities are combined:
                ``"prod"`` (product), ``"mean"`` (arithmetic mean),
                ``"gmean"`` (geometric mean) or ``"hmean"`` (harmonic mean).

        Returns:
            A float when ``text`` is a string, otherwise a list with one
            float per sentence (an empty list for empty input).

        Raises:
            ValueError: If ``reduce`` is not a recognized strategy and there
                is at least one sentence to score.
        """
        sentences = [text] if isinstance(text, str) else text
        scores: List[float] = []
        if not sentences:
            return scores

        # Reject an unknown strategy before paying for the model call; the
        # strategy does not depend on the sentences being scored.
        if reduce not in self._REDUCE_STRATEGIES:
            raise ValueError(f"Unrecognized scoring strategy: {reduce}")

        outputs = self._tokens_log_prob(sentences)
        for output in outputs:
            log_probs = output[0]
            tlen = log_probs.shape[0]

            # Every reduction is carried out in log space.
            if reduce == "prod":
                # log(prod p_i) = sum(log p_i)
                score = log_probs.sum()
            elif reduce == "mean":
                # log(sum(p_i) / n) = logsumexp(log p_i) - log(n)
                score = log_probs.logsumexp(0) - math.log(tlen)
            elif reduce == "gmean":
                # log((prod p_i) ** (1 / n)) = mean(log p_i)
                score = log_probs.mean(0)
            else:  # "hmean", guaranteed by the validation above
                # log(n / sum(1 / p_i)) = log(n) - logsumexp(-log p_i)
                score = log_probs.neg().logsumexp(0).neg() + math.log(tlen)

            if not log:
                score = score.exp()
            scores.append(score.item())

        return scores[0] if isinstance(text, str) else scores

    @overload
    def tokens_score(
        self, text: str, log: bool = False
    ) -> Tuple[List[float], List[int], List[str]]:
        ...

    @overload
    def tokens_score(
        self, text: List[str], log: bool = False
    ) -> List[Tuple[List[float], List[int], List[str]]]:
        ...

    def tokens_score(
        self, text: Union[str, List[str]], log: bool = False
    ) -> Union[
        Tuple[List[float], List[int], List[str]],
        List[Tuple[List[float], List[int], List[str]]],
    ]:
        """Return per-token scores for one sentence or a list of sentences.

        Args:
            text: A single sentence, or a list of sentences.
            log: When true the scores are the log-probabilities as returned
                by ``_tokens_log_prob``; otherwise they are exponentiated.

        Returns:
            For each sentence a ``(scores, ids, tokens)`` tuple in which the
            first two tensors produced by ``_tokens_log_prob`` have been
            converted to plain lists. A single tuple is returned when
            ``text`` is a string, otherwise a list of tuples (an empty list
            for empty input).
        """
        sentences = [text] if isinstance(text, str) else text
        outputs: List[Tuple[List[float], List[int], List[str]]] = []
        if not sentences:
            return outputs

        for log_probs, ids, tokens in self._tokens_log_prob(sentences):
            scores = log_probs if log else log_probs.exp()
            # Typing-only cast: no conversion happens at runtime.
            scores = cast(torch.DoubleTensor, scores)
            output = (scores.tolist(), ids.tolist(), tokens)
            outputs.append(output)

        return outputs[0] if isinstance(text, str) else outputs

    @classmethod
    def supported_model_names(cls) -> Iterable[str]:
        """Return the model names reported by ``_supported_model_names``."""
        return cls._supported_model_names()

    def _build(self, model_name: str, options: Dict[str, Any]) -> None:
        """Initialize the scorer; the base version only stores the name.

        Args:
            model_name: Name of the model, stored as ``self.model_name``.
            options: Keyword options given to the constructor. Unused here.
        """
        # pylint: disable=attribute-defined-outside-init, unused-argument
        self.model_name = model_name

    @abstractmethod
    def _tokens_log_prob(
        self, text: List[str]
    ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]:
        """Compute token-level log-probabilities for each sentence.

        Args:
            text: The sentences to score.

        Returns:
            One ``(log_probs, ids, tokens)`` triple per sentence. Callers in
            this class index ``log_probs`` along dimension 0 and convert
            ``log_probs`` and ``ids`` to lists.
        """
        ...  # pragma: no cover

    @classmethod
    @abstractmethod
    def _supported_model_names(cls) -> Iterable[str]:
        """Return the model names this scorer implementation supports."""
        ...  # pragma: no cover