|
import copy |
|
import importlib.metadata |
|
import json |
|
import os |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from packaging import version |
|
|
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.utils import is_torchdynamo_compiling, logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
class Cache(torch.nn.Module): |
|
""" |
|
Base, abstract class for all caches. The actual data structure is specific to each subclass. |
|
""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. |
|
|
|
Parameters: |
|
key_states (`torch.Tensor`): |
|
The new key states to cache. |
|
value_states (`torch.Tensor`): |
|
The new value states to cache. |
|
layer_idx (`int`): |
|
The index of the layer to cache the states for. |
|
cache_kwargs (`Dict[str, Any]`, `optional`): |
|
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of |
|
cache to be created. |
|
|
|
Return: |
|
A tuple containing the updated key and value states. |
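
        Example (illustrative sketch; `past_key_values`, `layer_idx` and `cache_position` are assumed to come from
        the calling attention layer, and the concrete behavior depends on the `Cache` subclass):

        ```python
        >>> key_states, value_states = past_key_values.update(
        ...     key_states, value_states, layer_idx, cache_kwargs={"cache_position": cache_position}
        ... )
        ```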
|
""" |
|
raise NotImplementedError("Make sure to implement `update` in a subclass.") |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
|
|
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cached states, if there is any.""" |
|
raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") |
|
|
|
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: |
|
"""Given the sequence length of the new inputs, returns the usable length of the cache.""" |
|
|
|
|
|
|
|
max_length = self.get_max_length() |
|
previous_seq_length = self.get_seq_length(layer_idx) |
|
if max_length is not None and previous_seq_length + new_seq_length > max_length: |
|
return max_length - new_seq_length |
|
return previous_seq_length |
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor): |
|
"""Reorders the cache for beam search, given the selected beam indices.""" |
|
for layer_idx in range(len(self.key_cache)): |
|
device = self.key_cache[layer_idx].device |
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) |
|
device = self.value_cache[layer_idx].device |
|
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) |
|
|
|
@property |
|
def seen_tokens(self): |
|
logger.warning_once( |
|
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " |
|
"model input instead." |
|
) |
|
if hasattr(self, "_seen_tokens"): |
|
return self._seen_tokens |
|
else: |
|
return None |
|
|
|
|
|
@dataclass |
|
class CacheConfig: |
|
""" |
|
Base class for cache configs |
|
""" |
|
|
|
    cache_implementation: Optional[str] = None
|
|
|
@classmethod |
|
def from_dict(cls, config_dict, **kwargs): |
|
""" |
|
Constructs a CacheConfig instance from a dictionary of parameters. |
|
Args: |
|
config_dict (Dict[str, Any]): Dictionary containing configuration parameters. |
|
**kwargs: Additional keyword arguments to override dictionary values. |
|
|
|
Returns: |
|
CacheConfig: Instance of CacheConfig constructed from the dictionary. |
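
        Example (illustrative sketch; the dictionary key shown is an assumption, concrete subclasses define their own fields):

        ```python
        >>> config = CacheConfig.from_dict({"cache_implementation": "static"})
        >>> config.cache_implementation
        'static'
        ```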
|
""" |
|
config = cls(**config_dict) |
|
to_remove = [] |
|
for key, value in kwargs.items(): |
|
if hasattr(config, key): |
|
setattr(config, key, value) |
|
to_remove.append(key) |
|
for key in to_remove: |
|
kwargs.pop(key, None) |
|
return config |
|
|
|
|
|
def to_json_file(self, json_file_path: Union[str, os.PathLike]): |
|
""" |
|
Save this instance to a JSON file. |
|
|
|
Args: |
|
json_file_path (`str` or `os.PathLike`): |
|
Path to the JSON file in which this configuration instance's parameters will be saved. |
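
        Example (illustrative sketch; the file name is an assumption):

        ```python
        >>> config = CacheConfig(cache_implementation="static")
        >>> config.to_json_file("cache_config.json")
        ```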
|
|
""" |
|
with open(json_file_path, "w", encoding="utf-8") as writer: |
|
config_dict = self.to_dict() |
|
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" |
|
|
|
writer.write(json_string) |
|
|
|
|
|
def to_dict(self) -> Dict[str, Any]: |
|
""" |
|
        Serializes this instance to a Python dictionary.

        Returns:
|
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. |
|
""" |
|
return copy.deepcopy(self.__dict__) |
|
|
|
|
|
def __iter__(self): |
|
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" |
|
for attr, value in copy.deepcopy(self.__dict__).items(): |
|
yield attr, value |
|
|
|
|
|
def __repr__(self): |
|
return f"{self.__class__.__name__} {self.to_json_string()}" |
|
|
|
def to_json_string(self): |
|
""" |
|
Serializes this instance to a JSON formatted string. |
|
Returns: |
|
str: JSON formatted string representing the configuration instance. |
|
""" |
|
return json.dumps(self.__dict__, indent=2) + "\n" |
|
|
|
|
|
def update(self, **kwargs): |
|
""" |
|
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, |
|
returning all the unused kwargs. |
|
|
|
Args: |
|
kwargs (`Dict[str, Any]`): |
|
Dictionary of attributes to tentatively update this class. |
|
|
|
Returns: |
|
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. |
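
        Example (illustrative sketch; `not_a_real_option` is a made-up key used only to show the return value):

        ```python
        >>> config = CacheConfig(cache_implementation="static")
        >>> unused = config.update(cache_implementation="sliding_window", not_a_real_option=1)
        >>> unused
        {'not_a_real_option': 1}
        ```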
|
""" |
|
to_remove = [] |
|
for key, value in kwargs.items(): |
|
if hasattr(self, key): |
|
setattr(self, key, value) |
|
to_remove.append(key) |
|
|
|
|
|
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} |
|
return unused_kwargs |
|
|
|
|
|
class StaticCache(Cache): |
|
""" |
|
Static Cache class to be used with `torch.compile(model)` and `torch.export()`. |
|
|
|
Parameters: |
|
config (`PretrainedConfig`): |
|
The configuration file defining the shape-related attributes required to initialize the static cache. |
|
max_batch_size (`int`): |
|
The maximum batch size with which the model will be used. |
|
max_cache_len (`int`): |
|
The maximum sequence length with which the model will be used. |
|
device (`torch.device`): |
|
The device on which the cache should be initialized. Should be the same as the layer. |
|
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
The default `dtype` to use when initializing the layer. |
|
|
|
Example: |
|
|
|
```python |
|
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache |
|
|
|
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2") |
|
>>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") |
|
|
|
>>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt") |
|
|
|
>>> # Prepare a cache class and pass it to model's forward |
|
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate |
|
>>> max_generated_length = inputs.input_ids.shape[1] + 10 |
|
>>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) |
|
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) |
|
    >>> past_kv = outputs.past_key_values # access cache filled with key/values from generation
|
``` |
|
""" |
|
|
|
def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: |
|
super().__init__() |
|
self.max_batch_size = max_batch_size |
|
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len |
|
|
|
self.head_dim = ( |
|
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads |
|
) |
|
|
|
self.dtype = dtype if dtype is not None else torch.float32 |
|
self.num_key_value_heads = ( |
|
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads |
|
) |
|
|
|
self.key_cache: List[torch.Tensor] = [] |
|
self.value_cache: List[torch.Tensor] = [] |
|
|
|
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) |
|
for idx in range(config.num_hidden_layers): |
|
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) |
|
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) |
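            # Note: `mark_static_address` tags the cache tensors as fixed data pointers, which prevents CUDA graph
            # breaks when the cache is updated under `torch.compile`; it cannot be applied while dynamo is tracing.
            # Registering the tensors as buffers is what lets `torch.export()` record the in-place mutations.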
|
|
|
|
|
|
|
|
|
|
|
if not is_torchdynamo_compiling(): |
|
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device)) |
|
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device)) |
|
new_layer_key_cache = getattr(self, f"key_cache_{idx}") |
|
new_layer_value_cache = getattr(self, f"value_cache_{idx}") |
|
torch._dynamo.mark_static_address(new_layer_key_cache) |
|
torch._dynamo.mark_static_address(new_layer_value_cache) |
|
self.key_cache.append(new_layer_key_cache) |
|
self.value_cache.append(new_layer_value_cache) |
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
""" |
|
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. |
|
It is VERY important to index using a tensor, otherwise you introduce a copy to the device. |
|
|
|
Parameters: |
|
key_states (`torch.Tensor`): |
|
The new key states to cache. |
|
value_states (`torch.Tensor`): |
|
The new value states to cache. |
|
layer_idx (`int`): |
|
The index of the layer to cache the states for. |
|
cache_kwargs (`Dict[str, Any]`, `optional`): |
|
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input |
|
                to know where to write in the cache.
|
|
|
Return: |
|
A tuple containing the updated key and value states. |
|
""" |
|
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
|
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device) |
|
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device) |
|
k_out = self.key_cache[layer_idx] |
|
v_out = self.value_cache[layer_idx] |
|
|
|
if cache_position is None: |
|
k_out.copy_(key_states) |
|
v_out.copy_(value_states) |
|
else: |
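            # `index_copy_(dim, index, source)` is equivalent to `tensor[:, :, index] = source`, but it is
            # compile-friendly and writes in place without creating intermediate copies.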
|
|
|
|
|
|
|
try: |
|
k_out.index_copy_(2, cache_position, key_states) |
|
v_out.index_copy_(2, cache_position, value_states) |
|
except NotImplementedError: |
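                # Some backends (e.g. MPS) do not implement `index_copy_`; fall back to advanced indexing.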
|
|
|
k_out[:, :, cache_position] = key_states |
|
v_out[:, :, cache_position] = value_states |
|
|
|
return k_out, v_out |
|
|
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
|
"""Returns the sequence length of the cached states that were seen by the model.""" |
|
|
|
|
|
|
|
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() |
|
|
|
def get_max_length(self) -> Optional[int]: |
|
"""Returns the maximum sequence length of the cached states.""" |
|
return self.max_cache_len |
|
|
|
def reset(self): |
|
"""Resets the cache values while preserving the objects""" |
|
for layer_idx in range(len(self.key_cache)): |
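            # In-place ops preserve the static addresses required by `torch.compile` / CUDA graphs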
|
|
|
self.key_cache[layer_idx].zero_() |
|
self.value_cache[layer_idx].zero_() |
|
|
|
|