|
import tensorflow as tf |
|
import numpy |
|
# Load the trained network once at import time so every predict() call reuses it.
# compile=False: the model is loaded without compiling it (no optimizer/loss setup).
# NOTE(review): safe_mode=False turns off Keras' safe-deserialization guard, so
# loading can execute code embedded in the file -- only load trusted model files.
model = tf.keras.models.load_model(
    "./model/trained_model.keras",
    compile=False,
    safe_mode=False,
)

# Print the layer-by-layer architecture of the loaded model.
model.summary()
|
|
|
|
|
def predict(question1, question2, threshold, verbose=False):
    """Predict whether two questions are duplicates.

    The pair is run through the module-level ``model`` as a single batch.
    The model output is split in half along axis 1 into one vector per
    question, and the sum of the element-wise product of the two halves is
    compared against ``threshold``.

    Args:
        question1 (str): First question.
        question2 (str): Second question.
        threshold (float): Desired threshold; the pair is reported as
            duplicates when the score is strictly greater than this value.
        verbose (bool, optional): If the results should be printed out.
            Defaults to False.

    Returns:
        tuple: ``(d, res)`` where ``d`` is the similarity score
        (``d.numpy()``) and ``res`` is the result of ``d > threshold``
        (``res.numpy()``), i.e. True if the questions are duplicates.
    """
    # Wrap the single question pair in a dataset with batch size 1. The second
    # tuple slot is None -- presumably the (unused) label position; confirm
    # against how the model was trained.
    generator = tf.data.Dataset.from_tensor_slices(
        (([question1], [question2]), None)
    ).batch(batch_size=1)

    # The model returns both question vectors concatenated along axis 1.
    v1v2 = model.predict(generator)

    # Split the concatenated output into its two halves.
    half = v1v2.shape[1] // 2
    v1 = v1v2[:, :half]
    v2 = v1v2[:, half:]

    # Sum of the element-wise product: a dot product for the single pair.
    # NOTE(review): reads as cosine similarity only if the model emits
    # normalized vectors -- confirm.
    d = tf.reduce_sum(v1 * v2)
    res = d > threshold

    if verbose:
        # The shape print used to run on every call; it is diagnostic output,
        # so it is now emitted only when verbose output was requested.
        print(v1.shape)
        print("Q1 = ", question1, "\nQ2 = ", question2)
        print("d = ", d.numpy())
        print("res = ", res.numpy())

    return d.numpy(), res.numpy()