from collections import deque

import faiss
import numpy as np
import torch
from torch import nn


class KNN:
    """
    kNN memory for a single element of the batch. Handles all attention heads.
    """

    def __init__(self, num_heads, head_dim, memories_size=16000, shrink_size=None, cache=None):
self.num_heads = num_heads
self.head_dim = head_dim
self.memories_size = memories_size
        self.shrink_size = shrink_size or int(memories_size * 1.1)
self.indexes = [faiss.IndexFlat(
self.head_dim, faiss.METRIC_INNER_PRODUCT) for _ in range(self.num_heads)]
        self.values = [deque() for _ in range(self.num_heads)]
self.cache = cache

    def __del__(self):
        if hasattr(self, 'indexes'):
            del self.indexes
        if hasattr(self, 'values'):
            del self.values

    def clear(self):
for index in self.indexes:
index.reset()
for value in self.values:
value.clear()

    def shrink(self):
        """Shrinks each index back to memories_size once it grows past shrink_size."""
        for i, index in enumerate(self.indexes):
            if index.ntotal > self.shrink_size:
                to_delete = index.ntotal - self.memories_size
                # Drop the oldest entries first; faiss expects int64 ids.
                index.remove_ids(np.arange(to_delete, dtype=np.int64))
                for _ in range(to_delete):
                    self.values[i].popleft()

    def add(self, key, value):
        # key, value: (num_heads, seq_len, head_dim)
        for i, k in enumerate(key):
            self.indexes[i].add(k)
        for i, v in enumerate(value):
            self.values[i].extend(v)
        if self.cache is not None:
            raise NotImplementedError("Cache for KNN not implemented")
            # self.cache.add(key)
        self.shrink()

    def search(self, query, k=32):
        """
        Searches for the queries in the keys' index.
        Returns the k most relevant keys and the corresponding values.
        """
        k = min(k, len(self.values[0]))
        if k <= 0:
            # Memory is empty: return numpy arrays (matching the branch below)
            # with a zero-sized retrieval axis, laid out as
            # (num_heads, seq_len, 0, head_dim).
            empty = np.empty(
                (query.shape[0], query.shape[1], 0, query.shape[2]),
                dtype=np.float32)
            return empty, empty.copy()
        Ks, Vs = [], []
        for i, q in enumerate(query):
            # D: distances, I: memory ids, K: reconstructed keys
            D, I, K = self.indexes[i].search_and_reconstruct(q, k=k)
            # Gather the values stored for the retrieved ids.
            V = np.take(self.values[i], indices=I, axis=0)
            Ks.append(K)
            Vs.append(V)
        # (num_heads, seq_len, k, head_dim)
        return np.stack(Ks, axis=0), np.stack(Vs, axis=0)
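

# A minimal usage sketch of KNN (illustrative only; it assumes float32 numpy
# inputs of shape (num_heads, seq_len, head_dim), and the sizes below are
# made up for the demo, not taken from the original).
def _demo_knn():
    num_heads, seq_len, head_dim = 2, 8, 16
    knn = KNN(num_heads=num_heads, head_dim=head_dim, memories_size=64)
    keys = np.random.randn(num_heads, seq_len, head_dim).astype(np.float32)
    values = np.random.randn(num_heads, seq_len, head_dim).astype(np.float32)
    knn.add(keys, values)
    queries = np.random.randn(num_heads, seq_len, head_dim).astype(np.float32)
    # Retrieve the 4 most similar stored keys (and their values) per position.
    K, V = knn.search(queries, k=4)
    assert K.shape == (num_heads, seq_len, 4, head_dim)
    assert V.shape == K.shape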


class KNNLayer:
    """
    kNN attention memory layer. Manages the memories for a batch: either one
    shared KNN, or one KNN per batch element (every element handled separately).
    """

    def __init__(self, config, share_memory=True, batch_size=None, memory_size=16000, shrink_size=None, n_jobs=4, cache=None):
        if not share_memory and batch_size is None:
            raise RuntimeError(
                "If share_memory is False, batch_size must be passed")
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.share_memory = share_memory
self.batch_size = batch_size
self.memory_size = memory_size
        self.shrink_size = shrink_size or int(self.memory_size * 1.1)
self.closed = False
if not share_memory:
self.knns = [KNN(self.num_heads, self.head_dim, memory_size,
self.shrink_size, cache=cache) for _ in range(self.batch_size)]
else:
self.knn = KNN(self.num_heads, self.head_dim,
memory_size, self.shrink_size, cache=cache)
faiss.omp_set_num_threads(n_jobs)

    def clear_batches(self, batch_indexes):
        if self.closed:
            return
        # With a shared memory there is nothing per-element to clear.
        if not self.share_memory:
            for idx in batch_indexes:
                self.knns[idx].clear()

    def clear(self):
if self.closed:
return
if self.share_memory:
self.knn.clear()
        else:
            for knn in self.knns:
                knn.clear()

    def add(self, keys, values):
        if self.closed:
            return
        # keys, values: (batch, num_heads, seq_len, head_dim)
        keys, values = keys.numpy(force=True), values.numpy(force=True)
if not self.share_memory:
for i, (key, value) in enumerate(zip(keys, values)):
self.knns[i].add(key, value)
else:
for key, value in zip(keys, values):
self.knn.add(key, value)

    def search(self, queries, k=32):
        """
        Searches every batch element's memory and pads the results to a common
        retrieval length, returning keys, values and a padding mask.
        """
        queries = queries.numpy(force=True)
keys, values = [], []
max_len = 0
        for i, query in enumerate(queries):
            # A shared memory uses one KNN for the whole batch; otherwise each
            # batch element has its own.
            knn = self.knn if self.share_memory else self.knns[i]
            key, value = knn.search(query, k)
            keys.append(key)
            values.append(value)
            max_len = max(max_len, key.shape[2])
        # Pad every element's retrievals to max_len and mask out the padding.
        masks = np.ones((len(keys), max_len), dtype=np.float32)
        for i, (key, value) in enumerate(zip(keys, values)):
            length = key.shape[2]
            if length == max_len:
                continue
            elif length > max_len:
                raise RuntimeError("Retrieved length exceeds max_len")
            sh = list(key.shape)
            sh[2] = max_len - length
            keys[i] = np.concatenate(
                (key, np.zeros(sh, dtype=np.float32)), axis=2)
            values[i] = np.concatenate(
                (value, np.zeros(sh, dtype=np.float32)), axis=2)
            masks[i, length:] = 0
return torch.from_numpy(np.stack(keys, axis=0)),\
torch.from_numpy(np.stack(values, axis=0)),\
torch.from_numpy(masks)

    def close(self):
        self.closed = True

    def open(self):
        self.closed = False

    def reset(self):
        self.open()
        self.clear()
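

# A minimal sketch of how KNNLayer is meant to be driven (the config stub and
# sizes here are assumptions for illustration, not from the original; any
# object exposing hidden_size and num_attention_heads works as the config).
def _demo_knn_layer():
    from types import SimpleNamespace
    config = SimpleNamespace(hidden_size=32, num_attention_heads=2)
    layer = KNNLayer(config, share_memory=False, batch_size=3, memory_size=64)
    batch, heads, seq_len, head_dim = 3, 2, 8, 16
    keys = torch.randn(batch, heads, seq_len, head_dim)
    values = torch.randn(batch, heads, seq_len, head_dim)
    layer.add(keys, values)
    queries = torch.randn(batch, heads, seq_len, head_dim)
    K, V, mask = layer.search(queries, k=4)
    # K, V: (batch, heads, seq_len, 4, head_dim); mask: (batch, 4)
    assert K.shape == (batch, heads, seq_len, 4, head_dim)
    assert mask.shape == (batch, 4)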


class ClearMemoryLayer(nn.Module):
    """
    Wraps a layer and clears the kNN memory of every batch element whose
    sequence starts with the BOS token before forwarding.
    """

    def __init__(self, knn_memory, bos_token, eos_token, next_layer):
        super().__init__()
        self.knn_memory = knn_memory
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.next_layer = next_layer

    def _clear_if_token(self, tokens, token):
        batches_to_clear = (tokens == token).any(dim=-1).nonzero()
        if len(batches_to_clear) > 0:
            # nonzero() returns shape (n, 1); take the batch-index column.
            self.knn_memory.clear_batches(batches_to_clear[:, 0])

    def forward(self, tokens, *args, **kwargs):
        # self._clear_if_token(tokens, self.bos_token)
        # Clear the memories of sequences that start with BOS.
        batches_to_clear = (tokens[:, 0] == self.bos_token).nonzero()
        if len(batches_to_clear) > 0:
            self.knn_memory.clear_batches(batches_to_clear[:, 0])
        res = self.next_layer(tokens, *args, **kwargs)
        # self._clear_if_token(tokens, self.eos_token)
        return res
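

# A minimal sketch of wiring ClearMemoryLayer in front of a decoder block.
# The identity "next layer", the config stub, and the token ids below are
# stand-ins for illustration, not part of the original.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=32, num_attention_heads=2)
    memory = KNNLayer(config, share_memory=False, batch_size=2, memory_size=64)

    class _Identity(nn.Module):
        def forward(self, tokens, *args, **kwargs):
            return tokens

    layer = ClearMemoryLayer(memory, bos_token=1, eos_token=2,
                             next_layer=_Identity())
    tokens = torch.tensor([[1, 5, 6], [7, 8, 9]])  # first sequence starts with BOS
    out = layer(tokens)  # clears batch element 0's memory, then forwards

    _demo_knn()
    _demo_knn_layer()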