Visual semantic with BERT-CNN
This model assigns an object-to-caption semantic relatedness score, which is useful for (1) diverse caption re-ranking and (2) generating soft labels for caption filtering when scraping captions from the internet.
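As a rough illustration of use (1), the relatedness score can re-rank the candidate captions a captioning model produces for one image. The sketch below assumes each candidate already has a base log-probability from the captioner and a relatedness score in [0, 1] from this model. The interpolation scheme, the weight `alpha`, and the function name `rerank` are hypothetical choices for illustration, not part of this model's release.

```python
import math
from typing import List, Tuple


def rerank(candidates: List[Tuple[str, float, float]], alpha: float = 0.5) -> List[str]:
    """Re-rank (caption, base_logprob, relatedness) triples, best first.

    Hypothetical scheme: interpolate the captioner's log-probability with the
    log of the relatedness score, so captions that fit the visual context move up.
    """
    def combined(item: Tuple[str, float, float]) -> float:
        _, logp, rel = item
        # Clamp the relatedness score to avoid log(0)
        return (1 - alpha) * logp + alpha * math.log(max(rel, 1e-9))

    return [cap for cap, _, _ in sorted(candidates, key=combined, reverse=True)]


# Made-up scores for one image whose detected object is "dog"
candidates = [
    ("a man riding a horse", -1.2, 0.10),
    ("a dog catching a frisbee", -1.5, 0.85),
    ("a dog running on grass", -1.9, 0.70),
]
print(rerank(candidates))
```

With `alpha=0` the ranking falls back to the captioner's own order. Larger values let the relatedness score override captions that are fluent but unrelated to the detected object.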
To take advantage of the overlap between the visual context and the caption, and to extract global information from each visual (e.g., object, scene), we use BERT as an embedding layer followed by a shallow CNN with a tri-gram kernel (Kim, 2014).
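A minimal sketch of the scoring head described above, assuming the BERT token embeddings for the visual/caption pair are already computed (here replaced by small hand-made vectors). It applies a tri-gram convolution in the style of Kim (2014), a ReLU, max-over-time pooling, and a sigmoid output. The filter count, the weights, and the name `trigram_cnn_score` are illustrative assumptions, not the released model's configuration.

```python
import math
from typing import List

Vector = List[float]


def trigram_cnn_score(embeddings: List[Vector],
                      filters: List[List[Vector]],
                      filter_bias: Vector,
                      out_weights: Vector,
                      out_bias: float) -> float:
    """Score a token sequence with a one-layer CNN using width-3 (tri-gram) kernels.

    embeddings: T token vectors of size D (stand-ins for BERT outputs), T >= 3
    filters:    F kernels, each made of 3 rows of D weights
    """
    num_windows = len(embeddings) - 2
    if num_windows < 1:
        raise ValueError("need at least 3 tokens for a tri-gram kernel")
    pooled = []
    for kernel, bias in zip(filters, filter_bias):
        # Slide the kernel over every window of 3 consecutive tokens, then apply ReLU
        feature_map = [
            max(0.0, bias + sum(w * e
                                for k in range(3)
                                for w, e in zip(kernel[k], embeddings[t + k])))
            for t in range(num_windows)
        ]
        # Max-over-time pooling keeps the strongest response of this filter
        pooled.append(max(feature_map))
    # Dense layer + sigmoid -> relatedness score in (0, 1)
    logit = out_bias + sum(w * h for w, h in zip(out_weights, pooled))
    return 1.0 / (1.0 + math.exp(-logit))


embeddings = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
filters = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    [[-1.0, 0.0], [0.0, -1.0], [0.0, 0.0]],
]
score = trigram_cnn_score(embeddings, filters, [0.0, 0.0], [0.5, 1.0], -2.0)
print(score)
```

Max-over-time pooling is what makes the head "global": each filter reports its strongest tri-gram match anywhere in the sequence, regardless of position.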
For datasets with fewer than 100K samples, please have a look at our shallow model.
The model is trained with a strict filter: a similarity-distance threshold of 0.4 between the object and its related caption.
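The same kind of threshold can be applied to scraped data. The sketch below assumes the model's output is a relatedness score where higher means more related, and that pairs scoring below 0.4 are dropped. The direction of the comparison and the function name `filter_pairs` are assumptions, since this card does not specify exactly how the distance threshold maps onto the output score.

```python
from typing import Iterable, List, Tuple


def filter_pairs(scored: Iterable[Tuple[str, str, float]],
                 threshold: float = 0.4) -> List[Tuple[str, str, float]]:
    """Keep (visual, caption, score) triples whose score reaches the threshold.

    Assumption: higher score = more related. The kept scores can also serve
    as soft labels for a downstream caption filter.
    """
    return [(vis, cap, s) for vis, cap, s in scored if s >= threshold]


# Made-up scores for illustration
scraped = [
    ("dog", "a dog catching a frisbee in the park", 0.82),
    ("dog", "click here to buy cheap flights", 0.05),
    ("pizza", "a slice of pizza on a paper plate", 0.61),
]
print(filter_pairs(scraped))
```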
For a quick start, please have a look at this Colab.
For the dataset
```bash
conda create -n BERT_visual python=3.6 anaconda
conda activate BERT_visual
pip install tensorflow==1.15.0
pip install --upgrade tensorflow_hub==0.7.0
git clone https://github.com/gaphex/bert_experimental/
```
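```python
import sys

import numpy as np
import pandas as pd
import tensorflow as tf  # TensorFlow 1.15 (graph mode with tf.Session)
from sklearn.model_selection import train_test_split

# Make the cloned bert_experimental repository importable
sys.path.insert(0, "bert_experimental")
from bert_experimental.finetuning.text_preprocessing import build_preprocessor
from bert_experimental.finetuning.graph_ops import load_graph

# Read the tab-separated pairs (columns "visual" and "caption")
df = pd.read_csv("test.tsv", sep="\t")

# Join each pair into one input string: "<visual> ||| <caption>"
texts = []
delimiter = " ||| "
for vis, cap in zip(df.visual.tolist(), df.caption.tolist()):
    texts.append(delimiter.join((str(vis), str(cap))))
texts = np.array(texts)

# Keep the original order; only the first few rows are scored below
trX, tsX = train_test_split(texts, shuffle=False, test_size=0.01)

# Load the frozen BERT-CNN graph; its first op is taken as the input, its last op as the output
restored_graph = load_graph("frozen_graph.pb")
graph_ops = restored_graph.get_operations()
input_op, output_op = graph_ops[0].name, graph_ops[-1].name
print(input_op, output_op)
x = restored_graph.get_tensor_by_name(input_op + ":0")
y = restored_graph.get_tensor_by_name(output_op + ":0")

# Tokenizer with vocab.txt and a maximum sequence length of 64. The frozen graph only
# references the Python preprocessing function, so this call re-registers it;
# keep it even though py_func is not used directly.
preprocessor = build_preprocessor("vocab.txt", 64)
py_func = tf.numpy_function(preprocessor, [x], [tf.int32, tf.int32, tf.int32], name="preprocessor")

# Predictions
sess = tf.Session(graph=restored_graph)
print(trX[:4])
# Run the output tensor itself: tf.print returns an op, so sess.run on it would return None
y_out = sess.run(y, feed_dict={x: trX[:4].reshape((-1, 1))})
print(y_out)
```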
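The three int32 outputs of the preprocessor built from `vocab.txt` with length 64 are presumably the standard BERT inputs for a sentence pair: token ids, input mask, and segment ids. The sketch below mimics that packing for one "visual ||| caption" string. It assumes the preprocessor splits on the " ||| " delimiter into segments A and B, and it uses a toy whitespace tokenizer with a hypothetical vocabulary instead of WordPiece, so it is illustrative only and not the bert_experimental implementation. The name `pack_pair` is hypothetical.

```python
from typing import Dict, List, Tuple


def pack_pair(text: str, vocab: Dict[str, int], max_len: int = 64,
              delimiter: str = " ||| ") -> Tuple[List[int], List[int], List[int]]:
    """Pack "visual ||| caption" as [CLS] A [SEP] B [SEP], padded to max_len.

    Toy sketch: whitespace tokenization and a hypothetical vocab, not WordPiece.
    Returns (input_ids, input_mask, segment_ids).
    """
    if max_len < 3:
        raise ValueError("max_len must leave room for [CLS] and two [SEP] tokens")
    visual, caption = text.split(delimiter, 1)
    unk = vocab["[UNK]"]
    a = [vocab.get(tok, unk) for tok in visual.lower().split()]
    b = [vocab.get(tok, unk) for tok in caption.lower().split()]
    # Truncate the longer segment first until the pair fits with 3 special tokens
    while len(a) + len(b) > max_len - 3:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    ids = [vocab["[CLS]"]] + a + [vocab["[SEP]"]] + b + [vocab["[SEP]"]]
    segment_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1)
    input_mask = [1] * len(ids)
    # Pad all three arrays to the fixed length
    pad = max_len - len(ids)
    return (ids + [vocab["[PAD]"]] * pad,
            input_mask + [0] * pad,
            segment_ids + [0] * pad)


vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
         "dog": 4, "a": 5, "catching": 6, "frisbee": 7}
ids, mask, seg = pack_pair("dog ||| a dog catching a frisbee", vocab, max_len=12)
print(ids, mask, seg)
```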