"""
Attention conversion helpers
"""
from functools import partial
from typing import Any

import torch.nn as nn
from tqdm import tqdm

def convert_attention(model: nn.Module,
                      attention_config: dict,
                      train_attention: bool = False,
                      remove_base_attn: bool = True):
    """
    Convert all of a model's attention layers as specified by attention_config

    Note: attention_config is expected to support both key and attribute
    access (e.g., an OmegaConf DictConfig)
    """
    softmax_attns = []
    if 'softmax_attentions' in attention_config:
        softmax_attns = attention_config['softmax_attentions']
    if attention_config.attention_type != 'softmax':
        layers = traverse_layers(model)
        for layer_idx, layer in enumerate(tqdm(layers, desc='Converting attentions...')):
            if layer_idx not in softmax_attns:
                layer.self_attn = convert_llama_attention(
                    layer, attention_config, layers, train_attention, remove_base_attn,
                )
                layer.self_attn.converted = True
            else:  # Freeze any preserved softmax attention layers
                for p in layer.parameters():
                    p.requires_grad = False
    else:
        print(f'-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions')
    return model
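
# Usage sketch (illustrative, not part of the original file): only
# 'attention_type' and 'softmax_attentions' are read here directly; any other
# config keys are forwarded to the attention class (see convert_llama_attention):
#
#   from omegaconf import OmegaConf
#   attention_config = OmegaConf.create({
#       'attention_type': 'lolcats_llama',
#       'softmax_attentions': [],  # layer indices to leave as softmax
#   })
#   model = convert_attention(model, attention_config, train_attention=True)
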
def toggle_attention(llama_model: nn.Module, train: bool = False):
    """
    Make attentions trainable if train is True
    -> Set train_attention = False when finetuning
    """
    for layer in traverse_layers(llama_model):
        layer.self_attn.train_attention = train
    return llama_model
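
# For example (a sketch of intended use, per the docstring above): a
# distillation recipe would call toggle_attention(model, train=True) while
# training the learned attentions, then toggle_attention(model, train=False)
# before finetuning.
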
def remove_base_attention(llama_model: nn.Module):
    """
    Remove the teacher (base) attentions after distillation, if they were kept
    """
    for layer in traverse_layers(llama_model):
        if getattr(layer.self_attn, 'base_attn', False):
            del layer.self_attn.base_attn
    return llama_model

def traverse_layers(model: nn.Module, verbose: bool = False):
    """
    Return the model's list of decoder layers, handling CausalLM-wrapped,
    base, and PEFT-wrapped models
    """
    try:
        layers = model.model.layers
        if verbose:
            print('-> Loading from model.model.layers')
    except AttributeError as e:  # if base model
        if verbose:
            print(e)
        try:
            layers = model.layers
            if verbose:
                print('-> Loading from model.layers')
        except AttributeError as e1:  # if PEFT model
            if verbose:
                print(e1)
            layers = model.base_model.model.model.layers
            if verbose:
                print('-> Loading from model.base_model.model.model.layers')
    return layers
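
# The three lookups above correspond to the usual Hugging Face module layouts:
#   LlamaForCausalLM:        model.model.layers
#   LlamaModel (base):       model.layers
#   PEFT-wrapped CausalLM:   model.base_model.model.model.layers
# e.g., `for layer in traverse_layers(model): print(layer.self_attn)`
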
def convert_llama_attention(layer: nn.Module,
                            attention_config: dict,
                            layers: list[nn.Module],  # list of model layers
                            train_attention: bool = False,
                            remove_base_attn: bool = True):
    """
    Convert a single layer's attention as specified by attention_config
    """
    return get_attention(**attention_config)(
        base_attn=layer.self_attn,
        layer_idx=layer.self_attn.layer_idx,  # Transformers v4.36
        max_layer_idx=len(layers) - 1,
        train_attention=train_attention,
        remove_base_attn=remove_base_attn,
    )
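
# Note: get_attention(**attention_config) returns a functools.partial over the
# attention class, so any extra keys in attention_config (beyond
# 'attention_type') become constructor kwargs alongside the per-layer
# arguments bound here.
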
def get_attention(attention_type: str, **kwargs: Any):
    """
    Get the linear attention class; either purely linear or linear with sliding window
    -> 'linear' == 'lolcats_llama'
    -> 'linear and sliding_window' == 'lolcats_llama_window_*'
    """
    kwargs['attention_type'] = attention_type
    if attention_type == 'lolcats_llama':
        from .linear_attention import LolcatsLinearAttention
        return partial(LolcatsLinearAttention, **kwargs)
    elif attention_type == 'lolcats_llama_window_tk':
        from .linear_attention import LolcatsTKWindowAttention
        return partial(LolcatsTKWindowAttention, **kwargs)
    elif attention_type == 'lolcats_llama_window_sw':
        from .linear_attention import LolcatsSlidingWindowAttention
        return partial(LolcatsSlidingWindowAttention, **kwargs)
    elif attention_type == 'lolcats_llama_window_sw_linear':
        from .linear_attention.linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
        return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
    ## Experimental chunked linear attentions below
    elif attention_type == 'lolcats_long_llama_window_tk':
        from .linear_attention import LolcatsTKWindowLongAttention
        return partial(LolcatsTKWindowLongAttention, **kwargs)
    elif attention_type == 'lolcats_long_llama_window_sw':
        from .linear_attention import LolcatsSlidingWindowLongAttention
        return partial(LolcatsSlidingWindowLongAttention, **kwargs)
    ## TK generation build (requires ThunderKittens)
    elif attention_type == 'lolcats_llama_window_tk_gen':
        from .linear_attention import LolcatsWindowAttentionTKGen
        return partial(LolcatsWindowAttentionTKGen, **kwargs)
    else:
        print(f'-> attention_type {attention_type} not handled... returning None')
        return None
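
# Usage sketch (illustrative values): get_attention returns a class wrapped in
# functools.partial rather than an instance, so each layer instantiates it:
#
#   attn_cls = get_attention('lolcats_llama')
#   new_attn = attn_cls(base_attn=layer.self_attn,
#                       layer_idx=0, max_layer_idx=31,  # e.g., a 32-layer model
#                       train_attention=True, remove_base_attn=True)
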
def get_attention_cache(attention_type: str, past_key_values: Any = None):
    """
    Determine how we store past keys and values when generating
    """
    if attention_type is None:
        return past_key_values

    # print(f'Returning attention cache based on attention_type == {attention_type}')
    ## TK generation build (requires ThunderKittens)
    elif 'lolcats_llama_window_tk_gen' in attention_type:
        from .linear_attention import LinearAttentionTKWindowGenerationCache
        return LinearAttentionTKWindowGenerationCache()
    elif 'llama_window_tk' in attention_type:
        from .linear_attention import LinearAttentionTKWindowCache
        return LinearAttentionTKWindowCache()
    # 'llama_window_sw' also matches 'lolcats_llama_window_sw_linear';
    # both variants use the same sliding-window cache
    elif 'llama_window_sw' in attention_type:
        from .linear_attention import LinearAttentionSlidingWindowCache
        return LinearAttentionSlidingWindowCache()
    elif 'softmax' in attention_type:
        return past_key_values
    else:
        from .linear_attention import LinearAttentionState
        return LinearAttentionState()
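
# Usage sketch (assumed generation setup, not from this file): build the cache
# that matches the converted attention type before decoding, e.g.:
#
#   past_key_values = get_attention_cache('lolcats_llama_window_sw')
#   outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)
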