# Spaces: Running on T4
# Standard library
import time
import warnings
from abc import ABC
from copy import deepcopy
from typing import Optional

# Third-party
import torch

# Package-local helpers
from ..utils import add_start_docstrings, logging


# Module-level logger obtained from the package's logging helper.
logger = logging.get_logger(__name__)
# Shared documentation snippet describing the arguments and return value of a
# stopping criterion's `__call__`. Kept as a raw string so the markdown link
# syntax and backticks are preserved verbatim.
STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be scores for each vocabulary token before SoftMax
            or scores for each vocabulary token after SoftMax.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional stopping criteria specific kwargs.

    Return:
        `bool`. `False` indicates we should continue, `True` indicates we should stop.

"""
class StoppingCriteria(ABC):
    """Common base for every stopping criterion consulted during generation.

    The base class carries no logic of its own: calling an instance that did
    not override `__call__` raises `NotImplementedError`.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Report whether generation should stop.

        Args:
            input_ids: Token ids generated so far.
            scores: Prediction scores for the current step.
            kwargs: Extra, criterion-specific arguments.

        Returns:
            `True` to stop, `False` to continue (in subclasses).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        message = "StoppingCriteria needs to be subclassed"
        raise NotImplementedError(message)
class MaxLengthCriteria(StoppingCriteria):
    """
    Stops generation once the sequence length (last dimension of `input_ids`) reaches `max_length`. The length is
    taken from `input_ids` as passed in, so any prompt tokens contained in it count towards the limit.

    Args:
        max_length (`int`):
            The maximum length that the output sequence can have in number of tokens.
        max_position_embeddings (`int`, *optional*):
            The maximum model length, as defined by the model's `config.max_position_embeddings` attribute. When
            set, a one-time warning is logged if generation continues at or beyond this length.
    """

    def __init__(self, max_length: int, max_position_embeddings: Optional[int] = None):
        self.max_position_embeddings = max_position_embeddings
        self.max_length = max_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return `True` when the current sequence length is at least `max_length`."""
        seq_len = input_ids.shape[-1]
        if seq_len >= self.max_length:
            return True

        # Not stopping yet: warn (once) if we are already at or past the model's positional limit.
        position_limit = self.max_position_embeddings
        if position_limit is not None and seq_len >= position_limit:
            logger.warning_once(
                "This is a friendly reminder - the current text generation call will exceed the model's predefined "
                f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
                "exceptions, performance degradation, or nothing at all."
            )
        return False
class MaxNewTokensCriteria(StoppingCriteria):
    """
    Deprecated criterion that stops generation once the sequence length reaches `start_length + max_new_tokens`,
    i.e. the initial tokens are not counted against `max_new_tokens`. Constructing it emits a `FutureWarning`
    pointing to `MaxLengthCriteria` as the replacement.

    Args:
        start_length (`int`):
            The number of initial tokens.
        max_new_tokens (`int`):
            The maximum number of tokens to generate.
    """

    def __init__(self, start_length: int, max_new_tokens: int):
        # Deprecation notice is issued first, before any state is stored.
        deprecation_message = (
            "The class `MaxNewTokensCriteria` is deprecated. "
            f"Please use `MaxLengthCriteria(max_length={start_length + max_new_tokens})` "
            "with `max_length = start_length + max_new_tokens` instead."
        )
        warnings.warn(deprecation_message, FutureWarning)

        self.start_length = start_length
        self.max_new_tokens = max_new_tokens
        # Absolute length at which generation must stop.
        self.max_length = start_length + max_new_tokens

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return `True` when the current sequence length is at least `self.max_length`."""
        current_length = input_ids.shape[-1]
        return current_length >= self.max_length
class MaxTimeCriteria(StoppingCriteria):
    """
    Stops generation once more than `max_time` seconds have elapsed since a reference timestamp. By default the
    reference is the moment this object is constructed; pass `initial_timestamp` to use a different starting point.

    Args:
        max_time (`float`):
            The maximum allowed time in seconds for the generation.
        initial_timestamp (`float`, *optional*, defaults to `time.time()`):
            The start of the generation allowed time.
    """

    def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
        if initial_timestamp is None:
            # No explicit start given: the clock starts now.
            initial_timestamp = time.time()
        self.initial_timestamp = initial_timestamp
        self.max_time = max_time

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return `True` when the elapsed wall-clock time strictly exceeds `max_time`."""
        elapsed = time.time() - self.initial_timestamp
        return elapsed > self.max_time
class StoppingCriteriaList(list):
    """A `list` of stopping criteria that can itself be called like a single criterion."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return `True` as soon as any contained criterion asks to stop.

        Args:
            input_ids: Token ids generated so far, passed through to each criterion.
            scores: Prediction scores for the current step, passed through to each criterion.
            kwargs: Extra criterion-specific arguments, forwarded unchanged to each criterion.

        Returns:
            `True` if at least one criterion returns a truthy value, `False` otherwise
            (including when the list is empty).
        """
        # Forward kwargs: previously they were accepted here but never reached the criteria.
        return any(criteria(input_ids, scores, **kwargs) for criteria in self)

    @property
    def max_length(self) -> Optional[int]:
        """Maximum length enforced by the first length-based criterion in the list.

        Exposed as a property because callers read it as a plain attribute
        (`stopping_criteria.max_length`) and compare it with `None` / an `int`.

        Returns:
            The `max_length` of the first `MaxLengthCriteria` or `MaxNewTokensCriteria`
            found, or `None` when the list contains neither.
        """
        for criterion in self:
            if isinstance(criterion, (MaxLengthCriteria, MaxNewTokensCriteria)):
                return criterion.max_length
        return None
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
    """Return a deep copy of `stopping_criteria` that is checked against `max_length`.

    The value of `stopping_criteria.max_length` is read as an attribute and then:

    - if it is `None`, a `MaxLengthCriteria(max_length=max_length)` is appended to the copy;
    - if it differs from `max_length`, a `UserWarning` is emitted and the copy is left as is;
    - otherwise the copy is returned untouched.

    Args:
        stopping_criteria: The criteria list to validate. It is not modified.
        max_length: The `max_length` the caller intends to generate with.

    Returns:
        The (possibly extended) deep copy of `stopping_criteria`.
    """
    configured_max_length = stopping_criteria.max_length
    validated_criteria = deepcopy(stopping_criteria)

    if configured_max_length is None:
        # No length-based criterion present: add one on the copy only.
        validated_criteria.append(MaxLengthCriteria(max_length=max_length))
    elif configured_max_length != max_length:
        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)

    return validated_criteria