from collections import deque

import faiss
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class KNN:
    """
    k-NN memory for a single batch element. Maintains one FAISS index per
    attention head, plus a parallel FIFO of value vectors.
    """

    def __init__(self, num_heads, head_dim, memories_size=16000, shrink_size=None, cache=None):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.memories_size = memories_size
        # Allow the index to overflow by ~10% before shrinking back down.
        self.shrink_size = shrink_size or int(memories_size * 1.1)
        # One inner-product index per head; keys are head_dim-dimensional.
        self.indexes = [faiss.IndexFlat(self.head_dim, faiss.METRIC_INNER_PRODUCT)
                        for _ in range(self.num_heads)]
        self.values = [deque() for _ in range(self.num_heads)]
        self.cache = cache

    def __del__(self):
        if hasattr(self, 'indexes'):
            del self.indexes
            del self.values

    def clear(self):
        for index in self.indexes:
            index.reset()

        for value in self.values:
            value.clear()

    def shrink(self):
        """Shrinks each index back down to memories_size entries (FIFO eviction)."""

        for i, index in enumerate(self.indexes):
            if index.ntotal > self.shrink_size:
                to_delete = index.ntotal - self.memories_size
                # Drop the oldest keys from the index and the matching values
                # from the front of the deque.
                index.remove_ids(np.arange(0, to_delete, dtype=np.int64))

                for _ in range(to_delete):
                    self.values[i].popleft()

    def add(self, key, value):
        # key and value are per-head arrays of shape (num_heads, seq_len, head_dim).
        for i, k in enumerate(key):
            self.indexes[i].add(k)
        for i, v in enumerate(value):
            self.values[i].extend(v)

        if self.cache is not None:
            raise RuntimeError("Cache for KNN not implemented")

        self.shrink()

    def search(self, query, k=32):
        """
        Searches the per-head key indexes for each query vector.
        Returns the k most relevant keys and their corresponding values.
        """

        # Can't return more neighbours than we have stored.
        k = min(k, len(self.values[0]))

        if k <= 0:
            # Empty memory: return empty arrays shaped like the normal output,
            # (num_heads, seq_len, 0, head_dim).
            shape = (query.shape[0], query.shape[1], 0, query.shape[2])
            return np.empty(shape, dtype=np.float32), np.empty(shape, dtype=np.float32)

        Ks, Vs = [], []

        for i, q in enumerate(query):
            # D: distances, I: ids, K: reconstructed key vectors.
            D, I, K = self.indexes[i].search_and_reconstruct(q, k=k)
            V = np.take(self.values[i], indices=I, axis=0)
            Ks.append(K)
            Vs.append(V)

        # Stack over heads: (num_heads, seq_len, k, head_dim).
        return np.stack(Ks, axis=0), np.stack(Vs, axis=0)
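

# Illustrative usage of KNN for a single batch element. The shapes below are
# inferred from how KNNLayer calls this class; this is a sketch, not an
# official API example:
#
#   knn = KNN(num_heads=2, head_dim=64, memories_size=1024)
#   keys = np.random.randn(2, 128, 64).astype(np.float32)    # (heads, seq, head_dim)
#   values = np.random.randn(2, 128, 64).astype(np.float32)
#   knn.add(keys, values)
#   mem_k, mem_v = knn.search(keys, k=8)                      # (2, 128, 8, 64) each

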
class KNNLayer:
    """
    KNN attention memory layer. Manages one KNN per batch element
    (or a single shared KNN when share_memory is True).
    """

    def __init__(self, config, share_memory=True, batch_size=None, memory_size=16000, shrink_size=None, n_jobs=4, cache=None):
        if not share_memory and batch_size is None:
            raise RuntimeError(
                "If share_memory is False, batch_size should be passed")

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.share_memory = share_memory
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.shrink_size = shrink_size or int(self.memory_size * 1.1)
        self.closed = False

        if not share_memory:
            self.knns = [KNN(self.num_heads, self.head_dim, memory_size,
                             self.shrink_size, cache=cache) for _ in range(self.batch_size)]
        else:
            self.knn = KNN(self.num_heads, self.head_dim,
                           memory_size, self.shrink_size, cache=cache)

        faiss.omp_set_num_threads(n_jobs)

    def clear_batches(self, batch_indexes):
        if self.closed:
            return

        # Per-batch clearing only makes sense when each batch element has its
        # own memory; with a shared memory this is a no-op.
        if not self.share_memory:
            for idx in batch_indexes:
                self.knns[idx].clear()

    def clear(self):
        if self.closed:
            return

        if self.share_memory:
            self.knn.clear()
        else:
            for idx in range(len(self.knns)):
                self.knns[idx].clear()

    def add(self, keys, values):
        if self.closed:
            return

        # keys, values: (batch, num_heads, seq_len, head_dim) tensors.
        keys, values = keys.numpy(force=True), values.numpy(force=True)
        if not self.share_memory:
            for i, (key, value) in enumerate(zip(keys, values)):
                self.knns[i].add(key, value)
        else:
            for key, value in zip(keys, values):
                self.knn.add(key, value)

    def search(self, queries, k=32):
        queries = queries.numpy(force=True)
        keys, values = [], []
        max_len = 0

        if self.share_memory:
            for query in queries:
                key, value = self.knn.search(query, k)
                keys.append(key)
                values.append(value)
                max_len = max(max_len, key.shape[2])
        else:
            for i, query in enumerate(queries):
                key, value = self.knns[i].search(query, k)
                keys.append(key)
                values.append(value)
                max_len = max(max_len, key.shape[2])

        # Some batch elements may have retrieved fewer than max_len memories;
        # pad those with zeros and mark the padding in the mask.
        masks = np.ones((len(keys), max_len), dtype=np.float32)

        for i, (key, value) in enumerate(zip(keys, values)):
            length = key.shape[2]

            if length == max_len:
                continue
            elif length > max_len:
                raise RuntimeError(
                    "Retrieved more memories than max_len, which should be impossible")

            sh = list(key.shape)
            sh[2] = max_len - sh[2]
            keys[i] = np.concatenate(
                (key, np.zeros(sh, dtype=np.float32)), axis=2)
            values[i] = np.concatenate(
                (value, np.zeros(sh, dtype=np.float32)), axis=2)
            masks[i, length:] = 0

        return torch.from_numpy(np.stack(keys, axis=0)),\
            torch.from_numpy(np.stack(values, axis=0)),\
            torch.from_numpy(masks)

    def close(self):
        self.closed = True

    def open(self):
        self.closed = False

    def reset(self):
        self.open()
        self.clear()
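

# One way a consumer might fold the retrieved memories into attention scores.
# This is purely illustrative: the actual attention layer lives outside this
# module, and `queries` here is assumed to be a (batch, heads, seq, head_dim)
# tensor on the torch side.
#
#   mem_k, mem_v, mask = memory.search(queries, k=32)
#   # mem_k, mem_v: (batch, heads, seq, k, head_dim); mask: (batch, k)
#   scores = torch.einsum('bhqd,bhqkd->bhqk', queries, mem_k)
#   scores = scores.masked_fill(mask[:, None, None, :] == 0, float('-inf'))
#   weights = F.softmax(scores, dim=-1)
#   mem_out = torch.einsum('bhqk,bhqkd->bhqd', weights, mem_v)

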
class ClearMemoryLayer(nn.Module):
    """
    Wraps another layer and clears the KNN memory for batch elements whose
    sequence signals a new document (e.g. starts with bos_token).
    """

    def __init__(self, knn_memory, bos_token, eos_token, next_layer):
        super().__init__()

        self.knn_memory = knn_memory
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.next_layer = next_layer

    def _clear_if_token(self, tokens, token):
        # Clear the memories of every batch element that contains `token`
        # anywhere in its sequence (can be used with eos_token as well).
        batches_to_clear = (tokens == token).any(dim=-1).nonzero()

        if len(batches_to_clear) > 0:
            self.knn_memory.clear_batches(batches_to_clear[:, 0])

    def forward(self, tokens, *args, **kwargs):
        # A sequence starting with bos_token marks a new document, so its
        # per-batch memory is cleared before the wrapped layer runs.
        batches_to_clear = (tokens[:, 0] == self.bos_token).nonzero()

        if len(batches_to_clear) > 0:
            self.knn_memory.clear_batches(batches_to_clear[:, 0])

        res = self.next_layer(tokens, *args, **kwargs)

        return res
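

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The config object here is a stand-in
    # for a transformers-style config (only hidden_size and num_attention_heads
    # are read by KNNLayer), and nn.Identity() stands in for the wrapped
    # attention layer, so the example stays self-contained.
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=64, num_attention_heads=4)
    memory = KNNLayer(config, share_memory=False, batch_size=2, memory_size=256)

    batch, seq = 2, 16
    heads = config.num_attention_heads
    head_dim = config.hidden_size // heads

    keys = torch.randn(batch, heads, seq, head_dim)
    values = torch.randn(batch, heads, seq, head_dim)
    memory.add(keys, values)

    queries = torch.randn(batch, heads, seq, head_dim)
    mem_keys, mem_values, mask = memory.search(queries, k=8)
    # Expected shapes: (2, 4, 16, 8, 16), (2, 4, 16, 8, 16), (2, 8)
    print(mem_keys.shape, mem_values.shape, mask.shape)

    # ClearMemoryLayer resets a batch element's memory whenever its sequence
    # starts with bos_token, then delegates to the wrapped layer.
    wrapped = ClearMemoryLayer(memory, bos_token=0, eos_token=1,
                               next_layer=nn.Identity())
    tokens = torch.randint(2, 100, (batch, seq))
    wrapped(tokens)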