File size: 899 Bytes
2f500b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import pickle
import embedding
import random
import embed_set
import net
from tqdm import tqdm
from tensorflow.keras.models import load_model

# Module-level candidate-count setting. top_closest_vectors() takes a
# parameter of the same name (default 1) that limits how many of the closest
# entries it returns.
# NOTE(review): presumably this constant is meant to drive that parameter --
# confirm it is actually passed through at the call site.
top_p = 1

class SetLine:
    """A named entry paired with the embedding vector of its name.

    Instances carry two plain attributes:

    * ``name`` -- the value passed to the constructor as ``name``.
    * ``inp``  -- the result of ``embedding.getvec(name)``.
    """

    def __init__(self, name, inp):
        # NOTE(review): the ``inp`` argument is accepted but never read; the
        # stored ``inp`` attribute is always recomputed from ``name``.
        # Confirm with the callers whether that is intentional.
        self.name = name
        vector = embedding.getvec(name)
        self.inp = vector

# Load the pickled lookup set at import time (path is relative to the current
# working directory). The functions below iterate over it and read `.inp` and
# `.name` from its elements.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only use a set.pckl that comes from a trusted source.
with open("set.pckl", "rb") as f: dset = pickle.load(f)

# Load the Keras model (also at import time, also a relative path) whose
# predict() output is matched against the entries of `dset`.
model = load_model("net.h5")

def top_closest_vectors(input_vector, top_p=1):
    """Rank the entries of ``dset`` by distance to ``input_vector``.

    The distance is ``np.linalg.norm`` of the difference between an entry's
    ``inp`` attribute and ``input_vector`` (the Euclidean norm when both are
    1-D vectors).

    Args:
        input_vector: Vector to compare every entry against.
        top_p: Upper bound of the slice taken from the ranked list.

    Returns:
        A list of ``(distance, index)`` tuples ordered from closest to
        farthest, truncated with ``[:top_p]``. ``index`` is the entry's
        position in ``dset``. Entries with equal distance keep their
        original relative order.
    """
    ranked = []
    for position, entry in enumerate(dset):
        gap = np.linalg.norm(entry.inp - input_vector)
        ranked.append((gap, position))
    # list.sort is stable, so ties stay in dset order.
    ranked.sort(key=lambda pair: pair[0])
    return ranked[:top_p]

def generate(text):
    """Predict a follow-up for ``text`` and return the name of a matching entry.

    ``text`` is split on newlines and only the last three lines are used.
    Each of them is embedded with ``embedding.getvec``; when there are fewer
    than three lines, the window is left-padded with zero vectors of length
    ``net.vec_size``. The window is passed to ``model.predict`` as a batch of
    one, and the predicted vector is matched against ``dset`` with
    ``top_closest_vectors``.

    Args:
        text: Newline-separated input string.

    Returns:
        The ``name`` attribute of an entry picked at random from the
        ``top_p`` closest entries of ``dset`` (``top_p`` is the module-level
        setting).
    """
    window = 3
    # Embed only the lines that end up in the window; earlier lines would be
    # discarded anyway, so there is no point computing their vectors.
    recent = text.split("\n")[-window:]
    padding = [np.zeros(net.vec_size)] * (window - len(recent))
    vecs = padding + [embedding.getvec(line) for line in recent]

    batch = np.array([vecs])
    rvec = model.predict(batch)[0]

    # Pass the module-level top_p through so that changing the setting
    # actually widens the candidate pool that random.choice draws from.
    candidates = top_closest_vectors(rvec, top_p)
    _, index = random.choice(candidates)
    return dset[index].name