File size: 3,363 Bytes
edd38af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3eb9ffe
edd38af
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
from numba import njit
import math
import random
import pickle
import gradio as gr

def text_to_arr(text: str):
    """Encode *text* as a NumPy array of the Unicode code points of its
    lowercased characters (empty input yields an empty array)."""
    codepoints = [ord(ch) for ch in text.lower()]
    return np.array(codepoints)

@njit
def longest_common_substring(s1, s2):
    """Return the longest run of positions where s1 and s2 agree element-wise.

    NOTE: despite the name, this is ALIGNED matching (s1[i] vs s2[i] at the
    same index), not the general longest-common-substring problem.  When the
    arrays share no aligned element, both bounds stay -1 and the returned
    slice s1[-1:-1] is empty.
    """
    current_match_start = -1
    current_match_end = -1

    best_match_start = current_match_start
    best_match_end = current_match_end

    min_len = min(len(s1), len(s2))
    for i in range(min_len):
        if s1[i] == s2[i]:
            current_match_start = current_match_end = i
            j = 0
            # Bounds check BEFORE indexing: the original evaluated
            # s1[i+j] == s2[i+j] first, reading one element past min_len
            # whenever a match ran to the final position (IndexError in
            # plain Python; silent out-of-bounds read under @njit, which
            # disables bounds checking by default).
            while i + j < min_len and s1[i+j] == s2[i+j]:
               j += 1
            current_match_end = current_match_start + j

            if current_match_end - current_match_start > best_match_end - best_match_start:
                best_match_start = current_match_start
                best_match_end = current_match_end

    return s1[best_match_start:best_match_end]

def not_found_in(q, data):
    """Return True iff *q* does NOT occur as a contiguous run in any sequence of *data*.

    Fixes two defects in the original scan:
    - off-by-one: it compared the running match count against len(q)-1, so a
      match of only the first len(q)-1 elements was reported as "found";
    - no restart: on a mismatch it reset the count without re-testing the
      current element against q[0], so overlapping occurrences were missed
      (e.g. q=[1,2] inside [1,1,2]).
    An empty *q* is reported as not found, matching the original behaviour.
    """
    lq = len(q)
    if lq == 0:
        return True
    for seq in data:
        # Slide q across every window of seq; works for lists and numpy arrays
        # alike because comparisons are done element by element.
        for start in range(len(seq) - lq + 1):
            if all(seq[start + k] == q[k] for k in range(lq)):
                return False
    return True

class Layer:
    """Substring-tokenizing layer.

    Calling an instance maps an input array to a list of indices into
    ``common_strings`` — one token per occurrence of a learned substring.
    During training, each input is compared (via the aligned
    ``longest_common_substring``) against recently seen inputs, and
    sufficiently short shared substrings are added to the vocabulary.
    """

    def __init__(self, mem_len: int = 100, max_size: int = 6):
        # How many past inputs to retain for substring mining.
        self.mem_len = mem_len
        # Learned substrings (numpy arrays); tokens index into this list.
        self.common_strings = []
        # Recently seen inputs, consumed during training.
        self.previously_seen = []
        # Stored as requested size + 1 so the strict "<" test below
        # admits substrings up to max_size elements long.
        self.max_size = max_size + 1

    def __call__(self, input_arr, training: bool = True):
        tokens = []
        n = len(input_arr)
        # Emit the index of every learned substring found at every position.
        for start in range(n):
            for idx, pattern in enumerate(self.common_strings):
                end = start + pattern.shape[0]
                if end <= n and (input_arr[start:end] == pattern).all():
                    tokens.append(idx)

        if training:
            # Find the longest (but still < max_size) aligned overlap between
            # this input and any remembered one.
            best_len = 0
            best_idx = None
            best_sub = None
            for idx, seen in enumerate(self.previously_seen):
                candidate = longest_common_substring(input_arr, seen)
                if best_len < len(candidate) < self.max_size:
                    best_len = len(candidate)
                    best_idx = idx
                    best_sub = candidate

            if self.previously_seen:
                if best_idx is not None and len(best_sub) > 1:
                    # Retire the matched memory and learn the shared substring
                    # if it is not already part of the vocabulary.
                    self.previously_seen.pop(best_idx)
                    if not_found_in(best_sub, self.common_strings):
                        self.common_strings.append(best_sub)
                # Keep only the most recent mem_len memories.
                self.previously_seen = self.previously_seen[-self.mem_len:]
            self.previously_seen.append(input_arr)

        return tokens

def _load_pickle(path):
    # Read one pickled object from disk.
    # NOTE: pickle.load is only safe on trusted, locally produced files.
    with open(path, "rb") as fh:
        return pickle.load(fh)

# Pre-trained layers and token→line-index vote tables.
layer = _load_pickle("l1_large.pckl")
layer2 = _load_pickle("l2_large.pckl")
w = _load_pickle("w1_large.pckl")
w2 = _load_pickle("w2_large.pckl")

def generate(msg):
    """Produce a reply for *msg*: tokenize it through both layers, then let
    each token vote (via the w/w2 tables) for a candidate response line and
    return the line with the most votes."""
    if len(msg) < 4:
        # NOTE(review): `threeletterai` is never imported or defined in this
        # file — as written this branch raises NameError; confirm the missing
        # import/module.
        return threeletterai.getresp(msg)
    processed = layer(text_to_arr(msg), training=False)
    processed = np.array(processed)
    processed2 = layer2(processed, training=False)
#    print(processed)
#    print(processed2)
    # Vote accumulator, one slot per candidate response line.
    # assumes there are at most 40000 candidate lines — TODO confirm
    o = np.zeros(40000, dtype=np.int16)
    for a in processed:
        if a in w:
            o[w[a]] += 1
    for a in processed2:
        if a in w2:
            o[w2[a]] += 1
    # NOTE(review): `lines` is not defined anywhere in this file — presumably
    # the list of candidate response strings; verify where it is loaded.
    return lines[np.argmax(o)]

# Wire the generator into a minimal Gradio text-in/text-out web UI and serve it.
app = gr.Interface(fn=generate, inputs="text", outputs="text")
app.launch()