""" |
|
LoLCATs + ThunderKittens linear attention + sliding window for generation |
|
""" |
|
from typing import Optional, Tuple, List |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
try: |
|
from thunderkittens import hedgehog as tk_window_hedgehog_attention |
|
print(f"Successfully imported ThunderKittens for TK window attention") |
|
except: |
|
print(f"Failed to import ThunderKittens for TK window attention") |
|
|
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention |
|
from .linear_attention import LinearAttentionState |
|
|
|
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
    def __init__(self, *args, window_size: int = 64, **kwargs):
        super().__init__(*args, **kwargs)
        self.train_attention = False
        self.base_inference = False
        # NOTE: both window sizes are fixed at 64 to match the ThunderKittens kernel;
        # the `window_size` argument is not used here
        self.window_size = 64
        self.decode_window_size = 64

        # Preallocated CUDA buffers that the TK kernel writes its output and
        # linear-attention states into; shapes assume batch 1, 32 heads,
        # 8192 tokens, head dim 128
        b, h, l, d = 1, 32, 8192, 128
        self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device='cuda')
        self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device='cuda')
        self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device='cuda')

    def forward(self,
                hidden_states: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_value: Optional[Tuple[int, torch.Tensor, torch.Tensor]] = None,
                output_attentions: bool = False,
                use_cache: bool = False,
                **kwargs,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Generation-only forward pass combining linear attention with a sliding window
        of softmax attention over the most recent tokens
        -> Consistent with Hugging Face Transformers for easy use with their pretrained models
        """
        b, l, _ = hidden_states.size()
        assert past_key_value is not None, "past_key_value must be provided for generation"
        assert self.train_attention is False, "train_attention is not supported for generation"
        assert self.base_inference is False, "base_inference is not supported for generation"
        assert use_cache is True, "use_cache must be True for generation"
        past_key_value.window_size = self.decode_window_size
        q, k, v, kv_seq_len = self.process_qkv(hidden_states, attention_mask,
                                               position_ids, past_key_value)

        # Decoding: a single new token attends over the sliding window with softmax
        # attention and over all older tokens through the running linear-attention state
        if q.shape[2] == 1 and kv_seq_len > 1:
            f_q = self.feature_map_q(q)
            _kv = past_key_value.update_for_decoding(k, v, self.layer_idx,
                                                     self.feature_map_k)
            k_cache, v_cache, kv_state, k_state = _kv

            # Learned gates mixing the window (softmax) and linear-attention terms
            window_factors = F.sigmoid(self.window_factors)
            linear_factors = 1 - window_factors if self.affine_attention_factors else 1

            # Softmax attention over the cached sliding-window keys (unnormalized)
            a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k.shape[-1] ** -0.5)
            a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
            a_sm = window_factors * torch.exp(a_sm - a_sm_max)
            sum_sm = a_sm.sum(dim=-1, keepdim=True)

            # Combine the window and linear-attention terms, normalizing by the sum
            # of both attention denominators
            y_true = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
                      + linear_factors * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()))
            sum_ln = linear_factors * torch.einsum(
                'bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
            self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
        else:
            # Prefill: process the prompt with the fused ThunderKittens kernel, which
            # writes the attention output and the linear-attention (KV, K) states into
            # the preallocated buffers in-place
            b, h, l, d = q.shape
            device = q.device

            # Per-head mixing factors: betas gate the sliding-window term,
            # alphas gate the linear-attention term
            betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
            alphas = (1 - betas if self.affine_attention_factors else
                      torch.ones(betas.shape, dtype=torch.float32, device=device))
            # Learned feature-map weights passed to the kernel
            q_map = self.feature_map_q.mlp.layer
            k_map = self.feature_map_k.mlp.layer

            tk_window_hedgehog_attention(q.contiguous(), k.contiguous(), v.contiguous(),
                                         self.y_true, self.k_state, self.kv_state,
                                         q_map, k_map, alphas, betas)

            # Initialize the cache with the kernel's states and the last window of k, v
            past_key_value.update_with_kv(self.kv_state, self.k_state.unsqueeze(-2),
                                          k, v, self.layer_idx)

        y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
        y_true = self.o_proj(y_true)
        return y_true, None, past_key_value
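

# Illustrative reference (not used by the model): a pure-PyTorch sketch of the
# single-token decode step above, with the feature map and mixing factors passed in
# explicitly so the arithmetic can be checked on CPU without the ThunderKittens
# kernel or the cache class. The function name, the `feature_map` callable, and the
# scalar window_factor / linear_factor arguments are stand-ins for the learned
# self.feature_map_q / self.feature_map_k modules and self.window_factors gates.
def _reference_decode_step(q: torch.Tensor,         # (b, h, 1, d) query for the new token
                           k_cache: torch.Tensor,   # (b, h, w, d) keys in the sliding window
                           v_cache: torch.Tensor,   # (b, h, w, d) values in the sliding window
                           kv_state: torch.Tensor,  # (b, h, f, d) running sum of phi(k) outer v
                           k_state: torch.Tensor,   # (b, h, 1, f) running sum of phi(k)
                           feature_map: callable,
                           window_factor: float = 0.5,
                           linear_factor: float = 0.5) -> torch.Tensor:
    f_q = feature_map(q)
    # Sliding-window softmax term (unnormalized)
    a_sm = torch.einsum('bhmd,bhnd->bhmn', q.float(), k_cache.float()) * (k_cache.shape[-1] ** -0.5)
    a_sm = window_factor * torch.exp(a_sm - torch.amax(a_sm, dim=-1, keepdim=True))
    sum_sm = a_sm.sum(dim=-1, keepdim=True)
    # Window numerator plus linear-attention numerator from the running state
    y = (torch.einsum('bhmn,bhnd->bhmd', a_sm, v_cache.float())
         + linear_factor * torch.einsum('bhld,bhdf->bhlf', f_q.float(), kv_state.float()))
    # Normalize by the sum of both attention denominators
    sum_ln = linear_factor * torch.einsum('bhld,bhnd->bhl', f_q.float(), k_state.float())[..., None]
    return (y / (sum_sm + sum_ln)).to(q.dtype)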


class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
    """
    Class for `past_key_values`
    -> Alternative to the standard KV cache: we maintain a running "KV state" and
       "K state" for linear attention, plus a sliding-window cache of the most
       recent keys and values
    -> Modified from transformers.cache_utils.DynamicCache (v4.36)
    """
    def __init__(self, window_size: int = 64) -> None:
        super().__init__()
        self._seen_tokens = 0
        self._seen_tokens_by_layer: List[int] = []
        self.window_size = window_size

        # Per-layer running linear-attention states and sliding-window KV caches
        self.decode_kv_states: List[torch.Tensor] = []
        self.decode_k_states: List[torch.Tensor] = []
        self.k_cache: List[torch.Tensor] = []
        self.v_cache: List[torch.Tensor] = []

    def update_with_kv(self,
                       kv_state: torch.Tensor, k_state: torch.Tensor,
                       k: torch.Tensor, v: torch.Tensor,
                       layer_idx: int) -> None:
        """
        Update the cache after prefill for one layer: accumulate the KV and K
        linear-attention states and store the last `window_size` keys and values
        """
        if layer_idx == 0:
            self._seen_tokens += k.shape[2]
        self._seen_tokens_by_layer.append(k.shape[2])

        # Initialize the states on the first call for this layer; otherwise accumulate
        if len(self.decode_k_states) <= layer_idx:
            self.decode_kv_states.append(kv_state)
            self.decode_k_states.append(k_state)
        else:
            self.decode_kv_states[layer_idx] = self.decode_kv_states[layer_idx] + kv_state
            self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state

        # Keep only the most recent `window_size` tokens for sliding-window attention
        self.k_cache.append(k[:, :, -self.window_size:, :])
        self.v_cache.append(v[:, :, -self.window_size:, :])

    def update_for_decoding(self, k: torch.Tensor, v: torch.Tensor,
                            layer_idx: int, feature_map_k: callable
                            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Update the cache for a single decoding step: fold the token evicted from the
        sliding window into the linear-attention states, slide the window forward by
        one token, and return the updated caches and states
        """
        k_cache = self.k_cache[layer_idx]
        v_cache = self.v_cache[layer_idx]

        # Absorb the oldest windowed key/value into the running linear-attention states
        k_state = feature_map_k(k_cache[:, :, :1, :])
        v_state = v_cache[:, :, :1, :]
        kv_state = torch.einsum('bhlf,bhld->bhfd', k_state.float(), v_state.float()).to(k.dtype)
        self.decode_kv_states[layer_idx] += kv_state
        self.decode_k_states[layer_idx] += k_state

        # Slide the window: drop the oldest key/value, append the new one
        self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
        self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)

        if layer_idx == 0:
            self._seen_tokens += k.shape[-2]
        self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
        return (self.k_cache[layer_idx], self.v_cache[layer_idx],
                self.decode_kv_states[layer_idx], self.decode_k_states[layer_idx])
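

# Minimal usage sketch (illustrative, not part of the model code): exercises the cache
# on CPU with random tensors and a plain ReLU standing in for the learned feature maps,
# mirroring the prefill -> decode flow that LolcatsWindowAttentionTKGen.forward drives.
# The function name and all shapes here are arbitrary examples.
def _cache_usage_sketch():
    b, h, l, d, window = 1, 2, 16, 8, 4
    cache = LinearAttentionTKWindowGenerationCache(window_size=window)
    feature_map = torch.relu  # dummy stand-in for feature_map_q / feature_map_k

    # "Prefill": pretend a kernel already produced states for the pre-window tokens
    k = torch.randn(b, h, l, d)
    v = torch.randn(b, h, l, d)
    f_k = feature_map(k[:, :, :-window])
    kv_state = torch.einsum('bhlf,bhld->bhfd', f_k, v[:, :, :-window])
    k_state = f_k.sum(dim=2, keepdim=True)
    cache.update_with_kv(kv_state, k_state, k, v, layer_idx=0)

    # One decode step: a single new key/value slides the window forward by one token
    k_new, v_new = torch.randn(b, h, 1, d), torch.randn(b, h, 1, d)
    k_cache, v_cache, kv_state, k_state = cache.update_for_decoding(
        k_new, v_new, layer_idx=0, feature_map_k=feature_map)
    assert k_cache.shape == (b, h, window, d) and kv_state.shape == (b, h, d, d)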