import heapq
import pickle
import random

import numpy as np
from tensorflow.keras.models import load_model
from tqdm import tqdm

import embed_set
import embedding
import net
# Number of nearest dataset entries to sample from (a count, not a probability).
# NOTE(review): in the original code neither top_closest_vectors nor generate
# reads this value (the function parameter of the same name shadows it);
# confirm whether generate() is meant to pass it through.
top_p: int = 1
class SetLine:
    """A named dataset entry paired with its embedding vector.

    Instances expose ``name`` (the text that was embedded) and ``inp`` (the
    vector from ``embedding.getvec``), which are the attributes
    ``top_closest_vectors`` and ``generate`` read from ``dset`` entries.
    """

    def __init__(self, name, inp):
        # NOTE(review): `inp` is accepted but ignored -- the vector is always
        # recomputed from `name`. Confirm whether callers expect it to be used.
        self.name = name
        vector = embedding.getvec(name)
        self.inp = vector
# Load the dataset and the trained model at import time. Both paths are
# resolved against the current working directory, not this file's directory.
# `dset` must be a sequence whose items expose `.inp` (vector) and `.name`.
# NOTE(review): pickle.load can execute arbitrary code -- only load a
# set.pckl you produced yourself.
with open("set.pckl", "rb") as f:
    dset = pickle.load(f)
model = load_model("net.h5")
def top_closest_vectors(input_vector, top_p=1, dataset=None):
    """Return the ``top_p`` dataset entries nearest to ``input_vector``.

    Distance is the Euclidean norm of ``entry.inp - input_vector``.

    Args:
        input_vector: Query vector; must be subtractable from each entry's
            ``inp`` attribute.
        top_p: How many nearest entries to return (a count, not a
            probability). Values <= 0 yield an empty list.
        dataset: Sequence of objects exposing an ``inp`` vector. Defaults to
            the module-level ``dset``.

    Returns:
        A list of ``(distance, index)`` tuples in ascending distance order,
        where ``index`` is the entry's position in ``dataset``. Ties keep
        dataset order.
    """
    if dataset is None:
        dataset = dset
    distances = (
        (np.linalg.norm(entry.inp - input_vector), ind)
        for ind, entry in enumerate(dataset)
    )
    # heapq.nsmallest(k, it, key) is equivalent to sorted(it, key=key)[:k]
    # (stable on ties) but keeps only k candidates instead of sorting all.
    # max(..., 0) avoids the old `[:negative]` slice returning all-but-last.
    return heapq.nsmallest(max(top_p, 0), distances, key=lambda pair: pair[0])
def generate(text):
    """Return the ``name`` of a dataset entry near the model's prediction.

    The last three lines of ``text`` are embedded with ``embedding.getvec``.
    When there are fewer than three lines, zero vectors are prepended. The
    embeddings are passed to ``model`` as a single-sample batch. One of the
    ``top_p`` (module-level setting) ``dset`` entries nearest to the predicted
    vector is picked at random. Its ``name`` is returned.

    Raises:
        IndexError: if ``top_closest_vectors`` yields no candidates (e.g. the
            module-level ``top_p`` is <= 0 or ``dset`` is empty).
    """
    window = 3  # number of context lines fed to the model per prediction
    # Only the last `window` lines are ever used, so embed just those instead
    # of every line in `text`.
    recent = [embedding.getvec(line) for line in text.split("\n")[-window:]]
    padding = [np.zeros(net.vec_size) for _ in range(window - len(recent))]
    # assumes getvec returns 1-D vectors of length net.vec_size, giving a
    # batch of shape (1, window, net.vec_size) -- TODO confirm against net
    batch = np.array([padding + recent])
    predicted = model.predict(batch)[0]
    # Use the module-level `top_p` so sampling among several neighbours is
    # configurable; with top_p == 1 this is the single nearest entry.
    candidates = top_closest_vectors(predicted, top_p)
    _, index = random.choice(candidates)
    return dset[index].name