Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,313 Bytes
be791d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
from math import pi, log
import torch
from torch import nn, einsum
from einops import rearrange, repeat
# helper functions
def exists(val):
    """Tell whether `val` holds a value, i.e. is anything other than None."""
    if val is None:
        return False
    return True
def broadcat(tensors, dim = -1):
    """Concatenate `tensors` along `dim` after expanding every other axis to a shared size.

    All tensors must have the same number of dimensions. On each axis other
    than `dim`, the tensors may show at most two distinct sizes; each tensor is
    expanded to the largest size seen on that axis before concatenation.

    Args:
        tensors: sequence of tensors with equal rank.
        dim: axis to concatenate along; a negative value is offset by the rank.

    Returns:
        The result of `torch.cat` over the expanded tensors.
    """
    count = len(tensors)
    ranks = {len(tensor.shape) for tensor in tensors}
    assert len(ranks) == 1, 'tensors must all have the same number of dimensions'
    (rank,) = ranks
    if dim < 0:
        dim = dim + rank

    # sizes_by_axis[i] lists the size of axis i for every tensor, in input order
    sizes_by_axis = list(zip(*(list(tensor.shape) for tensor in tensors)))
    other_axes = [(axis, sizes) for axis, sizes in enumerate(sizes_by_axis) if axis != dim]
    assert all(len(set(sizes)) <= 2 for _, sizes in other_axes), 'invalid dimensions for broadcastable concatentation'

    # every non-concat axis is grown to the largest size found on it;
    # the concat axis keeps each tensor's own size
    target_sizes = [(max(sizes),) * count for _, sizes in other_axes]
    target_sizes.insert(dim, sizes_by_axis[dim])

    # transpose "per axis" back into one target shape per tensor
    shapes = zip(*target_sizes)
    expanded = [tensor.expand(*shape) for tensor, shape in zip(tensors, shapes)]
    return torch.cat(expanded, dim = dim)
# rotary embedding helper functions
def rotate_half(x):
    """Map every consecutive pair (a, b) on the last axis of `x` to (-b, a).

    The last axis is viewed as pairs of adjacent entries, each pair is
    replaced by (-second, first), and the pairs are flattened back so the
    output has the same shape as the input.
    """
    pairs = rearrange(x, '... (d r) -> ... d r', r = 2)
    first, second = pairs.unbind(dim = -1)
    swapped = torch.stack((-second, first), dim = -1)
    return rearrange(swapped, '... d r -> ... (d r)')
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1.):
    """Rotate a slice of the last axis of `t` by the angles in `freqs`.

    Args:
        freqs: rotation angles; the size of its last axis is the number of
            features that get rotated. It is converted with `freqs.to(t)`
            before use.
        t: tensor whose last axis holds the features.
        start_index: first feature index of the rotated slice.
        scale: multiplier applied to both the cosine and the sine term.

    Returns:
        A tensor built by concatenating the untouched features before the
        slice, the rotated slice, and the untouched features after it.
    """
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    stop_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'

    # split the feature axis into (before, rotated slice, after)
    before = t[..., :start_index]
    middle = t[..., start_index:stop_index]
    after = t[..., stop_index:]

    cos_term = middle * freqs.cos() * scale
    sin_term = rotate_half(middle) * freqs.sin() * scale
    return torch.cat((before, cos_term + sin_term, after), dim = -1)
# learned rotation helpers
def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
    """Apply the given rotation angles to `t` via `apply_rotary_emb`.

    When `freq_ranges` is given, `rotations` is first multiplied against it
    as an outer product and the two trailing axes are merged into one. Each
    angle is then repeated twice in a row along the last axis before being
    handed to `apply_rotary_emb`.

    Args:
        rotations: tensor of rotation angles.
        t: tensor to rotate.
        start_index: forwarded to `apply_rotary_emb`.
        freq_ranges: optional 1-D tensor of multipliers for the angles.
    """
    if freq_ranges is not None:
        spread = einsum('..., f -> ... f', rotations, freq_ranges)
        rotations = rearrange(spread, '... r f -> ... (r f)')

    paired = repeat(rotations, '... n -> ... (n r)', r = 2)
    return apply_rotary_emb(paired, t, start_index = start_index)
# classes
class RotaryEmbedding(nn.Module):
    """Rotary position embedding with optional xpos scaling.

    The module holds a table of base frequencies, converts sequence positions
    into rotation angles (`forward`) and applies them to queries / keys with
    `apply_rotary_emb`. Computed angles and xpos scales are memoised in plain
    dicts (`self.cache`, `self.cache_scale`) keyed by caller-supplied strings;
    nothing is ever evicted from them.

    Args:
        dim: rotary feature dimension; sizes the 'lang' / 'pixel' frequency
            tables and the xpos scale vector.
        custom_freqs: optional frequency tensor used as-is instead of one of
            the built-in tables.
        freqs_for: which built-in table to build: 'lang', 'pixel' or
            'constant'.
        theta: base of the 'lang' frequency schedule.
        max_freq: the 'pixel' table spans 1 .. max_freq / 2 (times pi).
        num_freqs: number of entries of the 'constant' (all ones) table.
        learned_freq: accepted for interface compatibility but currently has
            no effect, because the frequencies are registered as a buffer and
            not as a trainable parameter.
        use_xpos: enable xpos scaling; when set, use
            `rotate_queries_and_keys` rather than `rotate_queries_or_keys`.
        xpos_scale_base: divisor applied to the centred positions before they
            are used as exponents of the xpos scale.
        interpolate_factor: positions are divided by this value; must be >= 1.
        theta_rescale_factor: NTK-aware rescaling factor applied to `theta`.

    Raises:
        ValueError: if `freqs_for` is not a known modality and no
            `custom_freqs` is given.
        AssertionError: if `interpolate_factor` is below 1.
    """

    def __init__(
        self,
        dim,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        learned_freq = False,
        use_xpos = False,
        xpos_scale_base = 512,
        interpolate_factor = 1.,
        theta_rescale_factor = 1.
    ):
        super().__init__()

        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
        # A factor of 1 leaves theta unchanged, so skip the computation: its
        # exponent divides by (dim - 2) and would otherwise raise
        # ZeroDivisionError for dim == 2 even when no rescaling was requested.
        if theta_rescale_factor != 1.:
            theta *= theta_rescale_factor ** (dim / (dim - 2))

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        # memoised rotation angles (forward) and xpos scales (get_scale)
        self.cache = dict()
        self.cache_scale = dict()

        # NOTE: freqs is a buffer, so `learned_freq` is not honoured here
        self.register_buffer('freqs', freqs)

        # interpolation factors
        assert interpolate_factor >= 1.
        self.interpolate_factor = interpolate_factor

        # xpos
        self.use_xpos = use_xpos
        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base
        self.register_buffer('scale', scale)

    @staticmethod
    def _align_to_seq_dim(emb, t, seq_dim):
        """Insert singleton axes into a (seq_len, feature) tensor so it broadcasts against `t`.

        `emb` has its sequence axis directly in front of the feature axis.
        When `t` keeps other axes (e.g. heads) between its sequence axis
        `seq_dim` and its last axis, one singleton axis is added to `emb` for
        each of them. With `seq_dim` pointing at the second-to-last axis
        `emb` is returned unchanged.
        """
        if seq_dim >= 0:
            seq_dim -= t.ndim
        gap = -seq_dim - 2
        if gap <= 0:
            return emb
        return emb.reshape(emb.shape[0], *([1] * gap), emb.shape[-1])

    def get_seq_pos(self, seq_len, device, dtype, offset = 0):
        """Return positions `offset .. offset + seq_len - 1` divided by `interpolate_factor`."""
        # NOTE(review): positions are created directly in `dtype`; with
        # low-precision dtypes large indices may not be exactly representable
        # - confirm this is acceptable for the sequence lengths in use.
        return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim = -2, offset = 0):
        """Rotate `t` by the angles of its positions (non-xpos mode only).

        Args:
            t: tensor with features on the last axis.
            seq_dim: axis of `t` that indexes sequence positions.
            offset: value added to every position index.

        Raises:
            AssertionError: if the module was built with `use_xpos = True`.
        """
        assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
        freqs = self.forward(lambda: self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset), cache_key = f'freqs:{seq_len}|offset:{offset}')
        freqs = self._align_to_seq_dim(freqs, t, seq_dim)
        return apply_rotary_emb(freqs, t)

    def rotate_queries_and_keys(self, q, k, seq_dim = -2):
        """Rotate queries and keys with xpos scaling.

        Queries are multiplied by the xpos scale and keys by its inverse.
        The sequence length, device and dtype are taken from `q`.

        Args:
            q: query tensor with features on the last axis.
            k: key tensor, laid out like `q`.
            seq_dim: axis that indexes sequence positions.

        Returns:
            Tuple `(rotated_q, rotated_k)`.

        Raises:
            AssertionError: if the module was built without `use_xpos`.
        """
        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
        seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
        freqs = self.forward(lambda: seq, cache_key = f'freqs:{seq_len}')
        # the scale may come from the cache, so bring it to q's device as well
        # as its dtype (apply_rotary_emb already does this for freqs)
        scale = self.get_scale(lambda: seq, cache_key = f'scale:{seq_len}').to(device = device, dtype = dtype)
        freqs = self._align_to_seq_dim(freqs, q, seq_dim)
        scale = self._align_to_seq_dim(scale, q, seq_dim)
        rotated_q = apply_rotary_emb(freqs, q, scale = scale)
        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1)
        return rotated_q, rotated_k

    def get_scale(self, t, cache_key = None):
        """Return the xpos scale for positions `t`, memoised in `self.cache_scale`.

        Args:
            t: 1-D tensor of positions, or a zero-argument callable returning
                one (only invoked on a cache miss).
            cache_key: optional key under which the result is stored / looked up.

        Raises:
            AssertionError: if the module was built without `use_xpos`.
        """
        assert self.use_xpos

        if exists(cache_key) and cache_key in self.cache_scale:
            return self.cache_scale[cache_key]

        if callable(t):
            t = t()

        # exponent: positions centred on the middle of the sequence
        power = (t - len(t) // 2) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        if exists(cache_key):
            self.cache_scale[cache_key] = scale

        return scale

    def forward(self, t, cache_key = None):
        """Return rotation angles for positions `t`, memoised in `self.cache`.

        The positions are cast to the dtype of `self.freqs`, multiplied
        against the frequencies as an outer product, and every resulting
        angle is repeated twice in a row along the last axis.

        Args:
            t: tensor of positions, or a zero-argument callable returning one
                (only invoked on a cache miss).
            cache_key: optional key under which the result is stored / looked up.
        """
        if exists(cache_key) and cache_key in self.cache:
            return self.cache[cache_key]

        if callable(t):
            t = t()

        freqs = self.freqs
        freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

        if exists(cache_key):
            self.cache[cache_key] = freqs

        return freqs
|