import numpy as np
from numba import njit
from tqdm import tqdm
import random
from matplotlib import pyplot as plt
import pickle


def text_to_arr(text: str):
    # Lowercase the text and map each character to its Unicode code point.
    return np.array([ord(x) for x in text.lower()])
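# Example: text_to_arr("Hi!") -> array([104, 105, 33]); lowercasing first
# makes all downstream matching case-insensitive.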

@njit
def longest_common_substring(s1, s2):
    # Longest run where s1 and s2 agree at the same indices (an aligned
    # match, not the classic longest common substring at arbitrary offsets).
    current_match_start = -1
    current_match_end = -1

    best_match_start = current_match_start
    best_match_end = current_match_end

    min_len = min(len(s1), len(s2))
    for i in range(min_len):
        if s1[i] == s2[i]:
            current_match_start = current_match_end = i
            j = 0
            # Check the bound before indexing so the loop cannot read past
            # the end of either array.
            while i+j < min_len and s1[i+j] == s2[i+j]:
                j += 1
            current_match_end = current_match_start + j

            if current_match_end - current_match_start > best_match_end - best_match_start:
                best_match_start = current_match_start
                best_match_end = current_match_end

    return s1[best_match_start:best_match_end]
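# Illustrative check:
#   longest_common_substring(text_to_arr("hello"), text_to_arr("help me"))
#   -> the code points of "hel"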

def not_found_in(q, data):
    # True when q does not occur as a contiguous run in any sequence in data.
    lq = len(q)
    for l in data:
        count = 0
        for v in l:
            if v == q[count]:
                count += 1
            else:
                # Restart the match, re-testing v against the first element.
                count = 1 if v == q[0] else 0
            if count == lq:
                return False
    return True
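# Illustrative checks:
#   not_found_in([2, 3], [[1, 2, 3, 4]])  # -> False: the run 2,3 occurs
#   not_found_in([3, 2], [[1, 2, 3, 4]])  # -> True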

class Layer:
    def __init__(self, mem_len: int = 100, max_size: int = 6):
        self.mem_len = mem_len        # how many recent inputs to remember
        self.common_strings = []      # learned substrings, emitted as token indices
        self.previously_seen = []     # recent inputs mined for shared substrings
        self.max_size = max_size+1    # upper bound on a stored substring's length

    def __call__(self, input_arr, training: bool = True):
        # Emit the index of every common string found anywhere in the input.
        o = []
        li = len(input_arr)
        for i in range(li):
            for y, cs in enumerate(self.common_strings):
                if (i+cs.shape[0]) <= li and (input_arr[i:i+cs.shape[0]] == cs).all():
                    o.append(y)
        if training:
            # Find the longest (but not too long) aligned overlap between the
            # input and any remembered line.
            cl = 0
            n = None
            r = None
            for i, line in enumerate(self.previously_seen):
                t = longest_common_substring(input_arr, line)
                l = len(t)
                if l > cl and l < self.max_size:
                    cl = l
                    n = i
                    r = t
            if n is not None and len(r) > 1:
                # Promote the overlap to a common string if it is new.
                self.previously_seen.pop(n)
                if not_found_in(r, self.common_strings):
                    self.common_strings.append(r)
            self.previously_seen = self.previously_seen[-self.mem_len:]
            self.previously_seen.append(input_arr)
        return o
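# Minimal usage sketch (the strings are hypothetical): two lines sharing a
# short aligned run let the layer promote that run into common_strings,
# after which calls report the indices of the runs found in the input.
#
#   demo = Layer(mem_len=10, max_size=6)
#   demo(text_to_arr("cat sat"), training=True)
#   demo(text_to_arr("cat ran"), training=True)     # stores the run "cat "
#   demo(text_to_arr("a cat sat"), training=False)  # -> [0]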

def comparefilter(f1, f2):
    # Count positions where the thresholded (> hss) weight patterns of the
    # two filters agree; call the filters similar once that count reaches
    # half the number of keys in f1.
    o = 0
    hss = 0.5
    for k in f1:
        if k in f2:
            o += np.sum((f2[k] > hss) == (f1[k] > hss))
    return o >= len(f1)*hss
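# Illustrative (hypothetical weights; 104 is the code point of "h"):
#   f_a = {104: np.array([0.9, 0.1, 0.2])}
#   f_b = {104: np.array([0.8, 0.3, 0.1])}
#   comparefilter(f_a, f_b)  # -> True: the > 0.5 patterns agree at all 3 slots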

class StrConv:
    def __init__(self, filters: int, size: int = 3):
        self.filter_amount = filters
        self.filters = [{} for _ in range(filters)]  # per filter: code point -> per-position weights
        self.bias = np.zeros((self.filter_amount,))
        self.size = size  # window width in characters

    def regularize(self):
        # Occasionally blank out a filter that duplicates an earlier one.
        for n, f in enumerate(self.filters):
            for f2 in self.filters[:n]:
                if random.randint(0, 100) < 10 and comparefilter(f, f2):
                    self.filters[n] = {}

    def __call__(self, input_arr, training: bool = True, debug=False):
        if len(input_arr) <= self.size:
            return []
        # Slide a window of width self.size over the input, summing each
        # filter's per-position weight for the characters it sees.
        o = np.zeros((input_arr.shape[0]-self.size, self.filter_amount))
        for i in range(input_arr.shape[0]-self.size):
            for n, c in enumerate(input_arr[i:i+self.size]):
                for fn, f in enumerate(self.filters):
                    if c in f:
                        o[i, fn] += f[c][n]
        o += self.bias
        m = np.max(np.abs(o))
        if m != 0:
            o /= m  # normalize responses into [-1, 1]
        if debug:
            plt.imshow(o)
            plt.show()
        if training:
            for i in range(input_arr.shape[0]-self.size):
                for n, c in enumerate(input_arr[i:i+self.size]):
                    for fn, f in enumerate(self.filters):
                        if c in f:
                            # Exponential moving average toward the current response.
                            f[c][n] = o[i, fn]*0.1 + f[c][n]*0.9
                        else:
                            # Unseen character: random weights, seeded at this position.
                            f[c] = np.random.uniform(0, 1, self.size)
                            f[c][n] = o[i, fn]
            """
            s = 0
            for a in self.filters:
                for b in a:
                    s += np.sum(b)
            if s > 100:
                s /= self.filter_amount
                for a in self.filters:
                    for b in a:
                        a[b] = (a[b]/s).astype(dtype=np.int64)
            """
            # Penalize filters that fire everywhere.
            self.bias -= np.sum(o, axis=0)

        # Keep only the index of the strongest filter at each window position.
        maxed = np.zeros((o.shape[0],))
        for i in range(maxed.shape[0]):
            maxed[i] = np.argmax(o[i])
        return maxed
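# Minimal usage sketch (hypothetical sizes): each call yields, per window
# position, the index of the filter that responded most strongly.
#
#   conv = StrConv(filters=8, size=3)
#   feats = conv(text_to_arr("hello world"), training=True)
#   # feats.shape == (len("hello world") - 3,)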

with open("dataset.txt", "r") as f:
    lines = f.read().rstrip("\n").split("\n")[:40000]
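# Assumed format: dataset.txt holds one utterance per line, and line n+1 is
# treated as the response to line n (see the disabled training loop below).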

w = {}   # layer-1 feature index -> dataset lines where it fired
w2 = {}  # layer-2 feature index -> dataset lines where it fired

c = 0

# Load a previously trained model; the training pass below is left disabled.
with open("l1_large.pckl", "rb") as f: layer = pickle.load(f)
with open("l2_large.pckl", "rb") as f: layer2 = pickle.load(f)
with open("w1_large.pckl", "rb") as f: w = pickle.load(f)
with open("w2_large.pckl", "rb") as f: w2 = pickle.load(f)
"""
for n, text in tqdm(enumerate(lines[:-1])):
    if text.strip() != "" and lines[n+1].strip() != "" and text != lines[n+1]:
        t = layer(text_to_arr(text), training=True)
        t = layer(text_to_arr(text), training=False)
        c += 1
        # if c == 10:
        #     c = 0
        #     layer.regularize()
        #     layer2.regularize()
        if len(t) != 0:
            t2 = layer2(np.array(t), training=True)
            t2 = layer2(np.array(t), training=False)
            for a in t2:
                if a in w2:
                    w2[a].append(n+1)
                else:
                    w2[a] = [n+1,]
        for a in t:
            if a in w:
                w[a].append(n+1)
            else:
                w[a] = [n+1,]

for n, text in tqdm(enumerate(lines[:200])):
    if text.strip() != "" and lines[n+1].strip() != "" and text != lines[n+1]:
        t = layer(text_to_arr(text), training=True)
        t = layer(text_to_arr(text), training=False)
        c += 1
        # if c == 10:
        #     c = 0
        #     layer.regularize()
        #     layer2.regularize()
        if len(t) != 0:
            t2 = layer2(np.array(t), training=True)
            t2 = layer2(np.array(t), training=False)
            for a in t2:
                if a in w2:
                    w2[a].append(n+1)
                else:
                    w2[a] = [n+1,]
        for a in t:
            if a in w:
                w[a].append(n+1)
            else:
                w[a] = [n+1,]

with open("l1_large.pckl", "wb") as f: pickle.dump(layer, f)
with open("l2_large.pckl", "wb") as f: pickle.dump(layer2, f)
with open("w1_large.pckl", "wb") as f: pickle.dump(w, f)
with open("w2_large.pckl", "wb") as f: pickle.dump(w2, f)
"""

print(len(lines), "responses available")

import threeletterai

while True:
    msg = input("Message: ")
    if len(msg) < 4:
        # Too short for the learned features; fall back to the simpler model.
        print(threeletterai.getresp(msg))
        continue
    processed = layer(text_to_arr(msg), training=False)
    processed = np.array(processed)
    processed2 = layer2(processed, training=False)

    # Vote: every dataset line that shares a feature with the message gets a
    # point; np.add.at (unlike o[idx] += 1) also counts repeated indices.
    o = np.zeros(len(lines), dtype=np.int16)
    for a in processed:
        if a in w:
            np.add.at(o, w[a], 1)
    for a in processed2:
        if a in w2:
            np.add.at(o, w2[a], 1)
    print(lines[np.argmax(o)], f" {np.max(o)} sure")