ierhon committed on
Commit
60aeba2
1 Parent(s): 3eb9ffe

Upload main.py

Files changed (1)
main.py +244 -0
main.py ADDED
@@ -0,0 +1,244 @@
+ import numpy as np
+ from numba import njit
+ from tqdm import tqdm
+ import math
+ import random
+ from matplotlib import pyplot as plt
+ import pickle
+
+
+ # whitelist = "ёйцукенгшщзхъфывапролджэячсмитьбю "
+
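+ # Convert a string to a NumPy array of Unicode code points (lower-cased).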
+ def text_to_arr(text: str):
+     return np.array([ord(x) for x in text.lower()])
+
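+ # Note: despite the name, this returns the longest run of *aligned* matches
+ # (positions where s1[i] == s2[i]), not the general longest common substring.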
+ @njit
+ def longest_common_substring(s1, s2):
+     current_match_start = -1
+     current_match_end = -1
+
+     best_match_start = current_match_start
+     best_match_end = current_match_end
+
+     min_len = min(len(s1), len(s2))
+     for i in range(min_len):
+         if s1[i] == s2[i]:
+             current_match_start = current_match_end = i
+             j = 0
+             while i+j < min_len and s1[i+j] == s2[i+j]:  # check the bound before indexing
+                 j += 1
+             current_match_end = current_match_start + j
+
+             if current_match_end - current_match_start > best_match_end - best_match_start:
+                 best_match_start = current_match_start
+                 best_match_end = current_match_end
+
+     return s1[best_match_start:best_match_end]
+
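+ # True if the contiguous sequence q does not already occur in any array in
+ # data. Note the threshold is len(q)-1, so a run matching all but one element
+ # of q already counts as "found".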
+ def not_found_in(q, data):
+     for l in data:
+         count = 0
+         lq = len(q)-1
+         for v in l:
+             if v == q[count]:
+                 count += 1
+             else:
+                 count = 0
+             if count == lq:
+                 return False
+     return True
+
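+ # Layer keeps a rolling memory of recent inputs, mines short substrings that
+ # recur between inputs (via longest_common_substring), and on each call emits
+ # the indices of every stored common string found in the input.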
+ class Layer:
+     def __init__(self, mem_len: int = 100, max_size: int = 6):
+         self.mem_len = mem_len
+         self.common_strings = []
+         self.previously_seen = []
+         self.max_size = max_size+1
+
+     def __call__(self, input_arr, training: bool = True):
+         o = []
+         li = len(input_arr)
+         for i in range(li):
+             for y, cs in enumerate(self.common_strings):
+                 if (i+cs.shape[0]) <= li and (input_arr[i:i+cs.shape[0]] == cs).all():
+                     o.append(y)
+         if training:
+             cl = 0
+             n = None
+             for i, line in enumerate(self.previously_seen):
+                 t = longest_common_substring(input_arr, line)
+                 l = len(t)
+                 if l > cl and l < self.max_size:
+                     cl = l
+                     n = i
+                     r = t
+             if self.previously_seen != []:
+                 if n is not None and len(r) > 1:
+                     self.previously_seen.pop(n)
+                     if not_found_in(r, self.common_strings):
+                         self.common_strings.append(r)
+             self.previously_seen = self.previously_seen[-self.mem_len:]
+             self.previously_seen.append(input_arr)
+         return o
+
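+ # Rough similarity test between two filters: counts positions where both
+ # filters threshold the same way; hss doubles as threshold and match ratio.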
+ def comparefilter(f1, f2):
+     o = 0
+     hss = 0.5
+     for k in f1:
+         if k in f2:
+             o += np.sum((f2[k] > hss) == (f1[k] > hss))
+     return (o >= len(f1)*hss)
+
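+ # A convolution-like layer over code points: each filter maps a character
+ # code to per-offset weights, summed over a sliding window of width `size`.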
+ class StrConv:
+     def __init__(self, filters: int, size: int = 4):
+         self.filter_amount = filters
+         self.filters = [{} for _ in range(filters)] # [{43: [3 2 0 3]},]
+         self.bias = np.zeros((self.filter_amount,))
+         self.size = size
+
+     def regularize(self):
+         for n, f in enumerate(self.filters):
+             for f2 in self.filters[:n]:
+                 if random.randint(0, 100) < 10 and comparefilter(f, f2):
+                     self.filters[n] = {}
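+
+     # Slide a window over the input, score it against every filter, and
+     # return the argmax filter index per position. In training mode the
+     # matching filter weights drift toward the normalized activations.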
+     def __call__(self, input_arr, training: bool = True, debug=False):
+         if len(input_arr) <= self.size:
+             return []
+         o = np.zeros((input_arr.shape[0]-self.size, self.filter_amount))
+         for i in range(input_arr.shape[0]-self.size):
+             for n, c in enumerate(input_arr[i:i+self.size]):
+                 for fn, f in enumerate(self.filters):
+                     if c in f:
+                         o[i, fn] += f[c][n]
+         o += self.bias
+         m = np.max(np.abs(o))
+         if m != 0: o /= m
+         if debug:
+             plt.imshow(o)
+             plt.show()
+         if training:
+             for i in range(input_arr.shape[0]-self.size):
+                 for n, c in enumerate(input_arr[i:i+self.size]):
+                     for fn, f in enumerate(self.filters):
+                         if c in f:
+                             # s = np.sum(f[c])
+                             # if s > 1000:
+                             #     f[c] = (f[c]/(s/(self.size*1000))).astype(np.int64)
+                             self.filters[fn][c][n] = o[i, fn]*0.1 + f[c][n]*0.9
+                         else:
+                             f[c] = np.random.uniform(0, 1, (self.size,))
+                             f[c][n] = o[i, fn]
+             # for t in range(self.size, input_arr.shape[0]):
+             #     for f in range(self.filter_amount):
+             #         self.filters[f] = o[t-self.size, f]
+             """
+             s = 0
+             for a in self.filters:
+                 for b in a:
+                     s += np.sum(b)
+             if s > 100:
+                 s /= self.filter_amount
+                 for a in self.filters:
+                     for b in a:
+                         a[b] = (a[b]/s).astype(dtype=np.int64)
+             """
+             self.bias -= np.sum(o, axis=0)  # / o.shape[0]
+         # print(o)
+         maxed = np.zeros((o.shape[0],)) # could have different outputs, not only max of o, like o>(self.size//2) or o without processing
+         for i in range(maxed.shape[0]):
+             maxed[i] = np.argmax(o[i])
+         return maxed
+
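+ # Load the line-per-message dataset used both for training and as the
+ # pool of candidate responses.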
+ with open("dataset.txt", "r") as f:
+     lines = f.read().rstrip("\n").split("\n")[:40000]
+
+ w = {}
+ w2 = {}
+
+ c = 0
+
+ # layer = Layer(mem_len=1000, max_size=4)
+ # layer2 = Layer(mem_len=1000, max_size=6)
+
+ with open("l1_large.pckl", "rb") as f: layer = pickle.load(f)
+ with open("l2_large.pckl", "rb") as f: layer2 = pickle.load(f)
+ with open("w1_large.pckl", "rb") as f: w = pickle.load(f)
+ with open("w2_large.pckl", "rb") as f: w2 = pickle.load(f)
+ """
+ for n, text in tqdm(enumerate(lines[:-1])):
+     if text.strip() != "" and lines[n+1].strip() != "" and text != lines[n+1]:
+         t = layer(text_to_arr(text), training=True)
+         t = layer(text_to_arr(text), training=False)
+         c += 1
+         # if c == 10:
+         #     c = 0
+         #     layer.regularize()
+         #     layer2.regularize()
+         if len(t) != 0:
+             t2 = layer2(np.array(t), training=True)
+             t2 = layer2(np.array(t), training=False)
+             for a in t2:
+                 if a in w2:
+                     w2[a].append(n+1)
+                 else:
+                     w2[a] = [n+1,]
+         for a in t:
+             if a in w:
+                 w[a].append(n+1)
+             else:
+                 w[a] = [n+1,]
+
+ for n, text in tqdm(enumerate(lines[:200])):
+     if text.strip() != "" and lines[n+1].strip() != "" and text != lines[n+1]:
+         t = layer(text_to_arr(text), training=True)
+         t = layer(text_to_arr(text), training=False)
+         c += 1
+         # if c == 10:
+         #     c = 0
+         #     layer.regularize()
+         #     layer2.regularize()
+         if len(t) != 0:
+             t2 = layer2(np.array(t), training=True)
+             t2 = layer2(np.array(t), training=False)
+             for a in t2:
+                 if a in w2:
+                     w2[a].append(n+1)
+                 else:
+                     w2[a] = [n+1,]
+         for a in t:
+             if a in w:
+                 w[a].append(n+1)
+             else:
+                 w[a] = [n+1,]
+
+ with open("l1_large.pckl", "wb") as f: pickle.dump(layer, f)
+ with open("l2_large.pckl", "wb") as f: pickle.dump(layer2, f)
+ with open("w1_large.pckl", "wb") as f: pickle.dump(w, f)
+ with open("w2_large.pckl", "wb") as f: pickle.dump(w2, f)
+ """
+ # print(layer.filters)
+
+ # for arr in layer.common_strings:
+ #     print(''.join([chr(a) for a in arr]))
+
+ print(len(lines), "responses available")
+
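+ # Fallback responder for very short messages (under 4 characters).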
+ import threeletterai
+
+ while True:
+     msg = input("Message: ")
+     if len(msg) < 4:
+         print(threeletterai.getresp(msg))
+         continue
+     processed = layer(text_to_arr(msg), training=False)
+     processed = np.array(processed)
+     processed2 = layer2(processed, training=False)
+     # print(processed)
+     # print(processed2)
+     o = np.zeros(len(lines), dtype=np.int16)
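+     # Vote: every feature id recognized in the message adds 1 to each line
+     # where that feature was seen at build time. NumPy's fancy-indexed +=
+     # counts duplicate indices within one list only once (np.add.at would
+     # accumulate repeats).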
+     for a in processed:
+         if a in w:
+             o[w[a]] += 1
+     for a in processed2:
+         if a in w2:
+             o[w2[a]] += 1
+     print(lines[np.argmax(o)], f" {np.max(o)} sure")
+