Spaces:
Sleeping
Sleeping
import torch | |
class KVCache:
    """
    A key-value cache for the model.

    This class wraps a pre-allocated tensor together with a length counter so
    that new keys or values can be written into the buffer in place instead of
    re-allocating on every step, which is particularly useful for models that
    benefit from caching previous states, like transformers during
    autoregressive decoding.

    Attributes:
        data (torch.Tensor): The tensor storing keys or values. ``shape``
            reads dimensions 0 to 3, so it must have at least four dimensions.
        current_length (torch.Tensor): Zero-dimensional integer tensor holding
            the number of valid entries along the cache dimension. It is
            updated in place (``fill_`` / ``add_``), so any other holder of
            the same tensor sees the update.
    """

    def __init__(self, data, current_length):
        """
        Initialize the KVCache.

        Args:
            data (torch.Tensor): Pre-allocated tensor to store the keys or values.
            current_length (torch.Tensor): Zero-dimensional integer tensor with
                the initial number of valid entries.
        """
        self.data = data
        self.current_length = current_length

    def shape(self):
        """Return the shape of the data tensor with dimension 2 replaced by the current length."""
        return (
            self.data.shape[0],
            self.data.shape[1],
            self.current_length.item(),
            self.data.shape[3],
        )

    def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2):
        """
        Copy values from the current data at specified indices to a new location.

        The entries selected by ``indices`` along ``dim`` are written to the
        slots starting at ``prev_length``, and the current length is set to
        ``prev_length`` plus the number of copied entries.

        Args:
            indices (torch.Tensor): Indices of the data tensor to be copied.
            prev_length (int): Previous length before adding new data; the
                first destination slot.
            dim (int, optional): Dimension along which copying should be performed. Default is 2.
        """
        # index_select materialises a new tensor, so source and destination
        # never alias even when the index ranges overlap.
        tgt = self.data.index_select(dim, indices)
        dst = self.data.narrow(dim, prev_length, tgt.shape[dim])
        dst.copy_(tgt, non_blocking=True)
        self.current_length.fill_(prev_length + tgt.shape[dim])

    def cat(self, tensor: torch.Tensor, dim: int = 2):
        """
        Concatenate the given tensor with the current data.

        The tensor is written in place into the buffer right after the
        currently valid entries, and the current length is advanced by the
        size of ``tensor`` along ``dim``.

        Args:
            tensor (torch.Tensor): The tensor to be concatenated.
            dim (int, optional): The dimension along which concatenation should be done. Default is 2.

        Returns:
            torch.Tensor: A view of the data tensor along ``dim`` up to the current length.
        """
        dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
        dst.copy_(tensor)
        self.current_length.add_(tensor.shape[dim])
        # Slice along the same dimension that was just written to; a
        # hard-coded 2 here would return the wrong view whenever dim != 2.
        return torch.narrow(self.data, dim, 0, self.current_length)
def initialize_past_key_values(model):
    """
    Initialize past key and value states for a given transformer model.

    This function prepares key-value cache structures for the model, allowing it to store and reuse
    past key and value states during autoregressive decoding, which can improve efficiency.

    One zero-filled buffer is allocated per run of consecutive layers that share a device, on that
    device, so a model split across several devices keeps each layer's cache next to its weights.

    Args:
        model (nn.Module): The transformer model for which past key-value states need to be initialized.
            Its layers are looked up under ``model.model.layers`` first and ``model.layers`` otherwise.

    Returns:
        tuple:
            - past_key_values (list): One ``[key_cache, value_cache]`` pair of KVCache objects per layer.
            - past_key_values_data_list (list): The backing tensors, one per run of consecutive
              layers on the same device, each of shape
              ``(2 * layers_in_run, batch, num_key_value_heads, max_position_embeddings, head_dim)``.
            - current_length_data (torch.Tensor): A CPU tensor with one length counter per cache
              (two per layer).
    """
    # Extracting configuration from the model
    config = model.config
    # Initializing the batch size to 1, this can be modified if different batch sizes are required
    batch_size = 1
    num_layers = config.num_hidden_layers
    head_dim = config.hidden_size // config.num_attention_heads

    # Record the device of every layer, taken from its query projection weight.
    devices = []
    for i in range(num_layers):
        try:
            device = model.model.layers[i].self_attn.q_proj.weight.device
        except AttributeError:
            # The model exposes its layers directly instead of under ``.model``.
            device = model.layers[i].self_attn.q_proj.weight.device
        devices.append(device)

    # Collapse the per-layer devices into runs of consecutive layers on the same device.
    runs = []  # entries are [device, number_of_layers]
    for device in devices:
        if runs and runs[-1][0] == device:
            runs[-1][1] += 1
        else:
            runs.append([device, 1])

    # One buffer per run; each layer owns two consecutive slots (key, value) along dimension 0.
    past_key_values_data_list = [
        torch.zeros(
            layer_count * 2,
            batch_size,
            config.num_key_value_heads,
            config.max_position_embeddings,
            head_dim,
            device=device,
            dtype=model.dtype,
        )
        for device, layer_count in runs
    ]

    # Initialize tensor to store the current length of the cached data for all layers.
    # [IMPORTANT] It needs to be kept on CPU for quick access and updates.
    current_length_data = torch.zeros(num_layers * 2, dtype=torch.long, device="cpu")

    # Creating a KVCache for each pair of key and value in all layers.
    # Buffers are addressed by run position rather than by device ordinal, so this also works
    # when ``device.index`` is None (e.g. CPU) or the ordinals are not contiguous.
    past_key_values = []
    layer = 0
    for buffer, (_, layer_count) in zip(past_key_values_data_list, runs):
        for offset in range(layer_count):
            past_key_values.append(
                [
                    KVCache(buffer[2 * offset + j], current_length_data[2 * layer + j])
                    for j in range(2)
                ]
            )
            layer += 1
    return past_key_values, past_key_values_data_list, current_length_data