lavawolfiee commited on
Commit
6bc49a9
1 Parent(s): 076e659
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ flagged/
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import GPT2TokenizerFast, GPT2LMHeadModel
4
+
5
+ from gpt2_knn_attention import GPT2KNNAttention
6
+ from knn_memory import KNNLayer, ClearMemoryLayer
7
+
8
+
9
+ def inject_knn_in_gpt2(model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8):
10
+ layer = model.transformer.h[layer_ind].attn
11
+ state = layer.state_dict()
12
+ knn_layer = GPT2KNNAttention(
13
+ config, knn_memory, device, is_cross_attention=False, layer_idx=layer.layer_idx)
14
+ knn_state = knn_layer.state_dict()
15
+
16
+ for k, v in state.items():
17
+ knn_state[k] = v
18
+
19
+ knn_layer.load_state_dict(knn_state)
20
+
21
+ model.transformer.h[8].attn = knn_layer
22
+ model.transformer = ClearMemoryLayer(
23
+ knn_memory, bos_token_id, eos_token_id, model.transformer)
24
+ model.eval()
25
+
26
+
27
+ model_name = "gpt2"
28
+ tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
29
+ model = GPT2LMHeadModel.from_pretrained(model_name)
30
+ config = model.config
31
+ model.eval()
32
+
33
+ knn_memory = KNNLayer(config, share_memory=False, batch_size=1)
34
+ bos_token_id, eos_token_id = tokenizer.bos_token_id, tokenizer.eos_token_id
35
+ bos_token, eos_token = tokenizer.bos_token, tokenizer.eos_token
36
+
37
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
+ knn_model = inject_knn_in_gpt2(
39
+ model, knn_memory, bos_token_id, eos_token_id, device, layer_ind=8)
40
+ knn_model.load_state_dict(torch.load('gpt2_knn_attention.pt'))
41
+
42
+
43
+ def generate(text, temperature, max_new_tokens, top_p):
44
+ encoded_input = tokenizer(text, return_tensors='pt')
45
+ output = model.generate(**encoded_input, do_sample=True,
46
+ max_new_tokens=int(max_new_tokens), temperature=temperature, top_p=top_p)
47
+ return tokenizer.decode(output[0])
48
+
49
+
50
+ desc = "Попытка повторить статью от Google (Memorizing Transformers)[https://arxiv.org/abs/2203.08913]. "\
51
+ "В ней вводиться новый слой **KNNAttention**, который использует approximate kNN в базе с (key, value), чтобы делать attention по большому контексту. Это позволяет расширить контекст трансформера до размера книг и статей, несильно замедляя его.\n\n"\
52
+ "Я написал свои **KNNAttention**, переписал слой **GPT2Attention**, чтобы он использовал KNNAttention, а также написал несколько вспомогательный классов для всего этого.\n\n"\
53
+ "Я сел писать это за **3 недели** до дедлайна, но все равно не довел до результата, которого изначально хотел. Но я доволен проделанной работой :)"
54
+
55
+
56
+ demo = gr.Interface(
57
+ fn=generate,
58
+ inputs=[gr.inputs.Textbox(lines=5, label="Input Text"),
59
+ gr.Slider(0.001, 2.0, step=0.05, value=0.8, label='temperature'),
60
+ gr.Slider(1, 512, step=1, value=32, label='max_new_tokens'),
61
+ gr.Slider(0.1, 1.0, step=0.02, value=0.92, label='top_p')],
62
+ outputs=gr.outputs.Textbox(label="Generated Text"),
63
+ description=desc,
64
+ title="Memorizing Transformers",
65
+ examples=[
66
+ ["My name is Lewis and I like to", 0.8, 32, 0.92]
67
+ ]
68
+ )
69
+
70
+ demo.launch()
batched_dataloader.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from transformers import DataCollatorWithPadding
3
+
4
+ # some utils for training
5
+
6
+
7
+ class BooksBatcherIter:
8
+ def __init__(self, data_iter, batch_size, tokenizer, chunk_size=1024):
9
+ self.data_iter = data_iter
10
+ self.batch_size = batch_size
11
+ self.chunk_size = chunk_size
12
+ self.batch_fns = [self._batch_fn()]
13
+ self.collate_fn = DataCollatorWithPadding(tokenizer)
14
+
15
+ def _batch_fn(self):
16
+ for book in self.data_iter:
17
+ for i in range(0, len(book), self.chunk_size):
18
+ yield book[i:i+self.chunk_size]
19
+
20
+ def __iter__(self) -> 'BooksBatcherIter':
21
+ return self
22
+
23
+ def __next__(self) -> Any:
24
+ batch = []
25
+
26
+ try:
27
+ for b in self.batch_fns:
28
+ batch.append(next(b))
29
+ except StopIteration:
30
+ raise StopIteration
31
+
32
+ return self.collate_fn(batch)
33
+
34
+
35
+ class BooksBatcher:
36
+ def __init__(self, dataset, batch_size, tokenizer) -> None:
37
+ self.batch_size = batch_size
38
+ self.tokenizer = tokenizer
39
+ self.dataloader = DataLoader(
40
+ dataset=dataset,
41
+ batch_size=None, # return raw samples
42
+ shuffle=True,
43
+ num_workers=2,
44
+ prefetch_factor=4
45
+ )
46
+
47
+ def __iter__(self) -> 'BooksBatcherIter':
48
+ return BooksBatcherIter(iter(self.dataloader), self.batch_size, self.tokenizer)
gpt2_knn_attention.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:610319abc523c64f595d4cb8fec2ff7faeab331b05722206c6c42656eae0bdff
3
+ size 510408492
gpt2_knn_attention.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
7
+
8
+
9
+ class GPT2KNNAttention(GPT2Attention):
10
+ def __init__(self, config, knn_memory, device, is_cross_attention=False, layer_idx=None, num_retrieve_memories=32):
11
+ super().__init__(config, is_cross_attention, layer_idx)
12
+
13
+ self.knn_memory = knn_memory
14
+ self.device = device
15
+ self.num_retrieve_memories = num_retrieve_memories
16
+ self.knn_attn_dropout = nn.Dropout(config.attn_pdrop)
17
+ self.attn_comb_bias = nn.Parameter(torch.empty(self.num_heads,))
18
+ nn.init.normal_(self.attn_comb_bias, mean=0.0, std=1.0)
19
+ # self.attn_comb_bias = nn.Parameter(torch.full((self.num_heads,), 1.0))
20
+
21
+ def _knn_attn(self, query, key, value, mask, head_mask=None):
22
+ query = query.unsqueeze(-2)
23
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
24
+
25
+ if self.scale_attn_weights:
26
+ attn_weights = attn_weights / torch.full(
27
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
28
+ )
29
+
30
+ # Layer-wise attention scaling
31
+ if self.scale_attn_by_inverse_layer_idx:
32
+ attn_weights = attn_weights / float(self.layer_idx + 1)
33
+
34
+ # if not self.is_cross_attention:
35
+ # raise RuntimeError("KNN attention is not yet implemented for !cross_attention")
36
+ # # if only "normal" attention layer implements causal mask
37
+ # query_length, key_length = query.size(-3), key.size(-3)
38
+ # causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
39
+ # mask_value = torch.finfo(attn_weights.dtype).min
40
+ # # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
41
+ # # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
42
+ # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
43
+ # attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
44
+
45
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
46
+
47
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
48
+ attn_weights = attn_weights.type(value.dtype)
49
+ attn_weights = self.knn_attn_dropout(attn_weights)
50
+
51
+ # masking missing keys
52
+ sh = mask.size()
53
+ attn_weights = attn_weights * mask.view((sh[0], 1, 1, 1, sh[1]))
54
+
55
+ # Mask heads if we want to
56
+ if head_mask is not None:
57
+ attn_weights = attn_weights * head_mask
58
+
59
+ attn_output = torch.matmul(attn_weights, value)
60
+ attn_output.squeeze_(dim=-2)
61
+
62
+ return attn_output
63
+
64
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
65
+ attn_output, attn_weights = super()._attn(
66
+ query, key, value, attention_mask, head_mask)
67
+ knn_key, knn_value, knn_mask = self.knn_memory.search(
68
+ query, self.num_retrieve_memories)
69
+ g = torch.sigmoid(self.attn_comb_bias)[:, None, None]
70
+
71
+ if knn_key.numel() == 0:
72
+ return attn_output * (1 - g), attn_weights
73
+
74
+ knn_key, knn_value, knn_mask = knn_key.to(
75
+ self.device), knn_value.to(self.device), knn_mask.to(self.device)
76
+ knn_attn_output = self._knn_attn(
77
+ query, knn_key, knn_value, knn_mask, head_mask)
78
+
79
+ # combining two attentions
80
+ attn = knn_attn_output * g + attn_output * (1 - g)
81
+
82
+ return attn, attn_weights
83
+
84
+ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
85
+ raise RuntimeError(
86
+ "KNN attention is not yet implemented for _upcast_and_reordered_attn")
87
+
88
+ def forward(
89
+ self,
90
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
91
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
92
+ attention_mask: Optional[torch.FloatTensor] = None,
93
+ head_mask: Optional[torch.FloatTensor] = None,
94
+ encoder_hidden_states: Optional[torch.Tensor] = None,
95
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
96
+ use_cache: Optional[bool] = False,
97
+ output_attentions: Optional[bool] = False,
98
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
99
+ if encoder_hidden_states is not None:
100
+ if not hasattr(self, "q_attn"):
101
+ raise ValueError(
102
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
103
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
104
+ )
105
+
106
+ query = self.q_attn(hidden_states)
107
+ key, value = self.c_attn(encoder_hidden_states).split(
108
+ self.split_size, dim=2)
109
+ attention_mask = encoder_attention_mask
110
+ else:
111
+ query, key, value = self.c_attn(
112
+ hidden_states).split(self.split_size, dim=2)
113
+
114
+ query = self._split_heads(query, self.num_heads, self.head_dim)
115
+ key = self._split_heads(key, self.num_heads, self.head_dim)
116
+ value = self._split_heads(value, self.num_heads, self.head_dim)
117
+
118
+ # normalization of queries and keys reduces the effect of staleness
119
+ query, key = F.normalize(query, dim=-1), F.normalize(key, dim=-1)
120
+ new_memories = (key, value)
121
+
122
+ if layer_past is not None:
123
+ past_key, past_value = layer_past
124
+ key = torch.cat((past_key, key), dim=-2)
125
+ value = torch.cat((past_value, value), dim=-2)
126
+
127
+ if use_cache is True:
128
+ present = (key, value)
129
+ else:
130
+ present = None
131
+
132
+ if self.reorder_and_upcast_attn:
133
+ raise RuntimeError("Not implemented")
134
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
135
+ query, key, value, attention_mask, head_mask)
136
+ else:
137
+ attn_output, attn_weights = self._attn(
138
+ query, key, value, attention_mask, head_mask)
139
+
140
+ attn_output = self._merge_heads(
141
+ attn_output, self.num_heads, self.head_dim)
142
+ attn_output = self.c_proj(attn_output)
143
+ attn_output = self.resid_dropout(attn_output)
144
+
145
+ outputs = (attn_output, present)
146
+ if output_attentions:
147
+ outputs += (attn_weights,)
148
+
149
+ self.knn_memory.add(*new_memories)
150
+
151
+ return outputs # a, present, (attentions)
knn_memory.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+
3
+ import faiss
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from torch import nn
8
+
9
+
10
+ class KNN:
11
+ """
12
+ KNN for one element in batch. Handles all heads
13
+ """
14
+
15
+ def __init__(self, num_heads, head_dim, memories_size=16000, shrink_size=None, cache=None):
16
+ self.num_heads = num_heads
17
+ self.head_dim = head_dim
18
+ self.memories_size = memories_size
19
+ self.shrink_size = shrink_size or memories_size * 1.1
20
+ self.indexes = [faiss.IndexFlat(
21
+ self.head_dim, faiss.METRIC_INNER_PRODUCT) for _ in range(self.num_heads)]
22
+ self.values = [deque([]) for _ in range(self.num_heads)]
23
+ self.cache = cache
24
+
25
+ def __del__(self):
26
+ if hasattr(self, 'indexes'):
27
+ del self.indexes
28
+ del self.values
29
+
30
+ def clear(self):
31
+ for index in self.indexes:
32
+ index.reset()
33
+
34
+ for value in self.values:
35
+ value.clear()
36
+
37
+ def shrink(self):
38
+ """Shrinks index to memories_size"""
39
+
40
+ for i, index in enumerate(self.indexes):
41
+ if index.ntotal > self.shrink_size:
42
+ to_delete = index.ntotal - self.memories_size
43
+ index.remove_ids(np.arange(0, to_delete))
44
+
45
+ for _ in range(to_delete):
46
+ self.values[i].popleft()
47
+
48
+ def add(self, key, value):
49
+ for i, k in enumerate(key):
50
+ self.indexes[i].add(k)
51
+ for i, v in enumerate(value):
52
+ self.values[i].extend(v)
53
+
54
+ if self.cache is not None:
55
+ raise RuntimeError("Cache for KNN not implemented")
56
+ # self.cache.add(key)
57
+
58
+ self.shrink()
59
+
60
+ def search(self, query, k=32):
61
+ """
62
+ Searchs for query in keys' index.
63
+ Returns k most relevant keys and corresponding values
64
+ """
65
+
66
+ k = min(k, len(self.values[0]))
67
+
68
+ if k <= 0:
69
+ return torch.empty((query.shape[0], query.shape[1], 0, query.shape[2])),\
70
+ torch.empty(
71
+ (query.shape[0], query.shape[1], 0, query.shape[2]))
72
+
73
+ Ks, Vs = [], []
74
+
75
+ for i, q in enumerate(query):
76
+ D, I, K = self.indexes[i].search_and_reconstruct(q, k=k)
77
+ V = np.take(self.values[i], indices=I, axis=0)
78
+ Ks.append(K)
79
+ Vs.append(V)
80
+
81
+ return np.stack(Ks, axis=0), np.stack(Vs, axis=0)
82
+
83
+
84
+ class KNNLayer:
85
+ """
86
+ KNN Attention layer. Handles KNN's for batch (every elemnt separately)
87
+ """
88
+
89
+ def __init__(self, config, share_memory=True, batch_size=None, memory_size=16000, shrink_size=None, n_jobs=4, cache=None):
90
+ if not share_memory and batch_size is None:
91
+ raise RuntimeError(
92
+ "If share_memory is False, batch_size should be passed")
93
+
94
+ self.embed_dim = config.hidden_size
95
+ self.num_heads = config.num_attention_heads
96
+ self.head_dim = self.embed_dim // self.num_heads
97
+
98
+ self.share_memory = share_memory
99
+ self.batch_size = batch_size
100
+ self.memory_size = memory_size
101
+ self.shrink_size = shrink_size or self.memory_size * 1.1
102
+ self.closed = False
103
+
104
+ if not share_memory:
105
+ self.knns = [KNN(self.num_heads, self.head_dim, memory_size,
106
+ self.shrink_size, cache=cache) for _ in range(self.batch_size)]
107
+ else:
108
+ self.knn = KNN(self.num_heads, self.head_dim,
109
+ memory_size, self.shrink_size, cache=cache)
110
+
111
+ faiss.omp_set_num_threads(n_jobs)
112
+
113
+ def clear_batches(self, batch_indexes):
114
+ if self.closed:
115
+ return
116
+
117
+ if not self.share_memory:
118
+ for idx in batch_indexes:
119
+ self.knns[idx].clear()
120
+
121
+ def clear(self):
122
+ if self.closed:
123
+ return
124
+
125
+ if self.share_memory:
126
+ self.knn.clear()
127
+ else:
128
+ for idx in range(len(self.knns)):
129
+ self.knns[idx].clear()
130
+
131
+ def add(self, keys, values):
132
+ if self.closed:
133
+ return
134
+
135
+ keys, values = keys.numpy(force=True), values.numpy(force=True)
136
+ if not self.share_memory:
137
+ for i, (key, value) in enumerate(zip(keys, values)):
138
+ self.knns[i].add(key, value)
139
+ else:
140
+ for key, value in zip(keys, values):
141
+ self.knn.add(key, value)
142
+
143
+ def search(self, queries, k=32):
144
+ queries = queries.numpy(force=True)
145
+ keys, values = [], []
146
+ max_len = 0
147
+
148
+ if self.share_memory:
149
+ for query in queries:
150
+ key, value = self.knn.search(query, k)
151
+ keys.append(key)
152
+ values.append(value)
153
+ max_len = max(max_len, key.shape[2])
154
+ else:
155
+ for i, query in enumerate(queries):
156
+ key, value = self.knns[i].search(query, k)
157
+ keys.append(key)
158
+ values.append(value)
159
+ max_len = max(max_len, key.shape[2])
160
+
161
+ masks = np.ones((len(keys), max_len), dtype=np.float32)
162
+
163
+ for i, (key, value) in enumerate(zip(keys, values)):
164
+ l = key.shape[2]
165
+
166
+ if l == max_len:
167
+ continue
168
+ elif l > max_len:
169
+ raise RuntimeError("What? max_len is not max")
170
+
171
+ sh = list(key.shape)
172
+ sh[2] = max_len - sh[2]
173
+ keys[i] = np.concatenate(
174
+ (key, np.zeros(sh, dtype=np.float32)), axis=2)
175
+ values[i] = np.concatenate(
176
+ (value, np.zeros(sh, dtype=np.float32)), axis=2)
177
+ masks[i, l:] = 0
178
+
179
+ return torch.from_numpy(np.stack(keys, axis=0)),\
180
+ torch.from_numpy(np.stack(values, axis=0)),\
181
+ torch.from_numpy(masks)
182
+
183
+ def close(self):
184
+ self.closed = True
185
+
186
+ def open(self):
187
+ self.closed = False
188
+
189
+ def reset(self):
190
+ self.open()
191
+ self.clear()
192
+
193
+
194
+ class ClearMemoryLayer(nn.Module):
195
+ def __init__(self, knn_memory, bos_token, eos_token, next_layer):
196
+ super().__init__()
197
+
198
+ self.knn_memory = knn_memory
199
+ self.bos_token = bos_token
200
+ self.eos_token = eos_token
201
+ self.next_layer = next_layer
202
+
203
+ def _clear_if_token(self, tokens, token):
204
+ batches_to_clear = (tokens == token).any(dim=-1).nonzero()
205
+
206
+ if len(batches_to_clear) > 0:
207
+ self.knn_memory.clear_batches(batches_to_clear[0])
208
+
209
+ def forward(self, tokens, *args, **kwargs):
210
+ # self._clear_if_token(tokens, self.bos_token)
211
+
212
+ batches_to_clear = (tokens[:, 0] == self.bos_token).nonzero()
213
+
214
+ if len(batches_to_clear) > 0:
215
+ self.knn_memory.clear_batches(batches_to_clear[:, 0])
216
+
217
+ res = self.next_layer(tokens, *args, **kwargs)
218
+ # self._clear_if_token(tokens, self.eos_token)
219
+
220
+ return res
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers==4.28.1
2
+ torch==2.0.0
3
+ faiss-cpu==1.7.4
4
+ gradio==3.27.0
5
+ numpy==1.21.2
vector_cache.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import numpy as np
5
+
6
+
7
+ class VectorCache:
8
+ """
9
+ Caches vectors on disk so one can later build an index on them (indexes like IVF requires big amount of vetores for building)
10
+ """
11
+
12
+ def __init__(self, filename='vector_cache.memmap', d=768, size=7000000):
13
+ self.filename = filename
14
+ self.offset_file = filename + '.offset'
15
+ self.d = d
16
+ self.size = size
17
+
18
+ if os.path.isfile(filename):
19
+ mode = 'r+'
20
+ self.f = open(self.offset_file, mode)
21
+ data = json.load(self.f)
22
+ self.offset = data[0]
23
+ self.length = data[1]
24
+ else:
25
+ mode = 'w+'
26
+ self.f = open(self.offset_file, mode)
27
+ self.offset = 0
28
+ self.length = 0
29
+
30
+ self.db = np.memmap(filename, dtype=np.float32, mode='w+',
31
+ shape=(size, d), order='C')
32
+
33
+ def sync_offset(self):
34
+ self.f.seek(0)
35
+ self.f.truncate(0)
36
+ self.f.write(json.dumps([self.offset, self.length]))
37
+
38
+ def close(self):
39
+ self.db.flush()
40
+ self.db.close()
41
+
42
+ self.sync_offset()
43
+ self.f.flush()
44
+ self.f.close()
45
+
46
+ def add(self, vs):
47
+ l = len(vs)
48
+ to_end = self.size - self.offset
49
+
50
+ if to_end < l:
51
+ self.add(vs[:to_end])
52
+ self.add(vs[to_end:])
53
+ return
54
+
55
+ self.db[self.offset:self.offset+l+1, :] = vs
56
+ self.offset = (self.offset + l + 1) % self.size
57
+ self.length = min(self.length + l, self.size)