# Test2 / start.py
# AlterM's picture
# Duplicate from RisticksAI/ProfNet4
# 2f500b5
# raw
# history blame contribute delete
# 899 Bytes
import numpy as np
import pickle
import embedding
import random
import embed_set
import net
from tqdm import tqdm
from tensorflow.keras.models import load_model
# Module-level neighbour count; presumably how many of the closest dataset
# entries should be considered when picking a reply (TODO confirm intent).
# NOTE(review): top_closest_vectors() declares its own `top_p` parameter
# (default 1), which shadows this global inside that function.
top_p = 1
class SetLine:
    """One dataset entry: a text line (`name`) together with its vector (`inp`)."""

    def __init__(self, name, inp):
        """Store *name* and the vector that ``embedding.getvec`` returns for it.

        NOTE(review): the *inp* argument is accepted but never read; the
        stored vector is always recomputed from *name*.
        """
        self.name = name
        vector = embedding.getvec(name)
        self.inp = vector
# Load the pickled dataset. The code below reads `.inp` and `.name` from each
# element, so the entries are presumably SetLine instances (TODO confirm).
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# "set.pckl" must only ever come from a trusted source.
with open("set.pckl", "rb") as f:
    dset = pickle.load(f)

# Load the trained Keras model used by generate().
model = load_model("net.h5")
def top_closest_vectors(input_vector, top_p=1):
    """Return the dataset entries nearest to *input_vector*.

    Every element of the module-level ``dset`` is scored by the Euclidean
    norm of ``entry.inp - input_vector``.

    Args:
        input_vector: Vector to compare against each entry's ``inp``.
        top_p: Maximum number of results to return.

    Returns:
        A list of ``(distance, index)`` tuples sorted by ascending distance,
        truncated to the first *top_p* items. ``index`` is the position of
        the entry in ``dset``. Entries with equal distance keep their
        dataset order because the sort is stable.
    """
    scored = []
    for position, entry in enumerate(dset):
        gap = np.linalg.norm(entry.inp - input_vector)
        scored.append((gap, position))
    # Order by distance only, so ties are never broken by comparing indices.
    scored.sort(key=lambda pair: pair[0])
    return scored[:top_p]
def generate(text):
    """Predict a follow-up for *text* and return the name of a nearby dataset entry.

    The input is split on newlines and each line is embedded with
    ``embedding.getvec``. Only the last three line vectors are fed to the
    model; shorter inputs are left-padded with zero vectors of length
    ``net.vec_size``.

    Args:
        text: Input string; lines are separated by ``"\n"``.

    Returns:
        The ``name`` attribute of one dataset entry, chosen at random among
        the ``top_p`` entries closest to the model's predicted vector.

    Raises:
        IndexError: If no candidate is available (e.g. ``dset`` is empty),
            because ``random.choice`` is called on an empty list.
    """
    # Prepend three zero vectors so that at least three vectors always exist.
    padded = 3 * [np.zeros(net.vec_size)] + [
        embedding.getvec(line) for line in text.split("\n")
    ]
    window = padded[-3:]
    # Wrap in an outer list to form a batch of one sample; presumably the
    # model expects input of shape (1, 3, vec_size) — TODO confirm.
    batch = np.array([window])
    rvec = model.predict(batch)[0]
    # Pass the module-level `top_p` explicitly so that the setting is honoured
    # instead of silently falling back to the function's own default.
    candidates = top_closest_vectors(rvec, top_p)
    _, chosen_index = random.choice(candidates)
    return dset[chosen_index].name