
Visual semantic with BERT-CNN

To take advantage of the overlap between the visual context and the caption, and to extract global information from each visual context, we use BERT as an embedding layer followed by a shallow CNN with a tri-gram kernel (Kim, 2014).
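The snippet below is a minimal NumPy sketch of a Kim-style CNN head. It is not the trained model and does not show its real weights. Random vectors stand in for BERT token embeddings, and the following are all assumptions made for illustration: the sequence length of 64 (which mirrors the preprocessor length used later), the hidden size of 768 (BERT-base), the 8 filters, ReLU, max-over-time pooling, and the sigmoid output. The hypothetical function `cnn_trigram_score` shows how a width-3 (tri-gram) kernel turns a sequence of token vectors into a single relatedness score.

```python
import numpy as np


def cnn_trigram_score(token_emb, filters, bias, w_out, b_out):
    """Score one sequence with a shallow, Kim-style CNN head (illustrative only).

    token_emb: (seq_len, emb_dim) token vectors, standing in for BERT outputs
    filters:   (n_filters, 3, emb_dim) tri-gram convolution kernels
    bias:      (n_filters,) convolution bias
    w_out:     (n_filters,) weights of the final dense layer
    b_out:     scalar bias of the final dense layer
    """
    seq_len = token_emb.shape[0]
    # Slide a width-3 window over the sequence (valid convolution, no padding).
    windows = np.stack([token_emb[i:i + 3] for i in range(seq_len - 2)])
    # Each feature value is the dot product of one filter with one tri-gram window.
    feature_maps = np.einsum("wke,fke->wf", windows, filters) + bias
    feature_maps = np.maximum(feature_maps, 0.0)  # ReLU
    pooled = feature_maps.max(axis=0)             # max-over-time pooling
    logit = pooled @ w_out + b_out
    return 1.0 / (1.0 + np.exp(-logit))           # assumed sigmoid output in (0, 1)


# Hypothetical sizes: 64 tokens, 768-d embeddings (BERT-base), 8 filters.
rng = np.random.default_rng(0)
seq_len, emb_dim, n_filters = 64, 768, 8
score = cnn_trigram_score(
    rng.normal(size=(seq_len, emb_dim)),
    rng.normal(size=(n_filters, 3, emb_dim)) * 0.01,
    np.zeros(n_filters),
    rng.normal(size=n_filters),
    0.0,
)
print(float(score))
```

Max-over-time pooling keeps the strongest response of each filter anywhere in the sequence. That is one way a shallow CNN can capture global information from the visual context and caption pair regardless of where the matching words appear.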

This model can be used to assign an object-to-caption relatedness score, which is useful for (1) diverse caption re-ranking and (2) generating soft labels for caption filtering when scraping text-to-caption data from the internet.
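The following sketch shows one way such relatedness scores could be consumed downstream. The scores are made up, and both helpers (`rerank_captions`, `soft_label_filter`) and the 0.5 keep threshold are hypothetical. The re-ranking step only sorts by relatedness and does not implement any diversity-promoting strategy.

```python
def rerank_captions(scored_captions):
    """Sort (caption, relatedness) pairs from most to least related to the visual context."""
    return sorted(scored_captions, key=lambda pair: pair[1], reverse=True)


def soft_label_filter(scored_captions, keep_threshold=0.5):
    """Attach the score as a soft label and keep captions at or above the (hypothetical) threshold."""
    labelled = [{"caption": c, "soft_label": s} for c, s in scored_captions]
    return [row for row in labelled if row["soft_label"] >= keep_threshold]


# Made-up relatedness scores for the visual context "dog".
candidates = [
    ("a cat sleeping on a sofa", 0.12),
    ("a dog catching a frisbee in the park", 0.91),
    ("a man walking his dog on the beach", 0.78),
]

for caption, s in rerank_captions(candidates):
    print(f"{s:.2f}  {caption}")

print(soft_label_filter(candidates))
```

Keeping the score as a soft label, rather than a hard keep/drop decision, lets a later training stage weight noisy scraped captions instead of discarding them outright.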

The model is trained on data filtered with a strict similarity-distance threshold of 0.4 between each object and its related caption.
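Below is a minimal sketch of threshold-based pair filtering. It assumes the 0.4 threshold is applied as a minimum cosine similarity between an object vector and a caption vector. The direction of the comparison, the use of cosine similarity, and the toy 3-d vectors are all assumptions for illustration, not a description of the exact filtering used to build the training data.

```python
import numpy as np


def cosine_similarity(u, v):
    """Cosine similarity between two 1-d vectors."""
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))


def filter_pairs(pairs, threshold=0.4):
    """Keep (object, caption) pairs whose vectors reach the assumed similarity threshold."""
    kept = []
    for obj, cap, obj_vec, cap_vec in pairs:
        sim = cosine_similarity(obj_vec, cap_vec)
        if sim >= threshold:
            kept.append((obj, cap, round(sim, 3)))
    return kept


# Toy 3-d vectors standing in for real object and caption embeddings.
pairs = [
    ("dog", "a dog running on grass", np.array([1.0, 0.2, 0.0]), np.array([0.9, 0.3, 0.1])),
    ("dog", "a plate of pasta", np.array([1.0, 0.2, 0.0]), np.array([0.0, 0.1, 1.0])),
]
print(filter_pairs(pairs))
```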

For a quick start, please have a look at this Colab notebook.

For the dataset

```bash
conda create -n BERT_visual python=3.6 anaconda
conda activate BERT_visual
pip install tensorflow==1.15.0
pip install --upgrade tensorflow_hub==0.7.0
git clone https://github.com/gaphex/bert_experimental/
```
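```python
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Make the cloned bert_experimental repository importable.
sys.path.insert(0, "bert_experimental")

from bert_experimental.finetuning.text_preprocessing import build_preprocessor
from bert_experimental.finetuning.graph_ops import load_graph

# Each row pairs a visual context (column "visual") with a caption (column "caption").
df = pd.read_csv("test.tsv", sep='\t')

# The model reads one string per example in the form "visual ||| caption".
texts = []
delimiter = " ||| "

for vis, cap in zip(df.visual.tolist(), df.caption.tolist()):
    texts.append(delimiter.join((str(vis), str(cap))))

texts = np.array(texts)

trX, tsX = train_test_split(texts, shuffle=False, test_size=0.01)

# Load the frozen graph; its first op is used as the input and its last op as the output.
restored_graph = load_graph("frozen_graph.pb")

graph_ops = restored_graph.get_operations()
input_op, output_op = graph_ops[0].name, graph_ops[-1].name
print(input_op, output_op)

x = restored_graph.get_tensor_by_name(input_op + ':0')
y = restored_graph.get_tensor_by_name(output_op + ':0')

# Tokenization runs inside a numpy_function op, and Python function bodies are not
# serialized into the .pb file, so the BERT preprocessor (max length 64) is registered again.
preprocessor = build_preprocessor("vocab.txt", 64)
py_func = tf.numpy_function(preprocessor, [x], [tf.int32, tf.int32, tf.int32], name='preprocessor')

## Predictions
sess = tf.Session(graph=restored_graph)

print(trX[:4])

# Fetch the output tensor directly: in TF 1.x graph mode tf.print returns an op,
# so running it would return None instead of the relatedness scores.
y_out = sess.run(y, feed_dict={
    x: trX[:4].reshape((-1, 1))
})

print(y_out)
```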
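To score several candidate captions for a single visual context, the input can be built in the same `visual ||| caption` format and shaped as a column, as in the `feed_dict` above. `build_model_inputs` is a hypothetical helper written for illustration. It only reproduces the string format and the `(-1, 1)` reshape used in the example and does not touch the model.

```python
import numpy as np

DELIMITER = " ||| "


def build_model_inputs(visual, captions):
    """Pair one visual context with several captions and return an (n, 1) string array."""
    texts = [DELIMITER.join((str(visual), str(cap))) for cap in captions]
    return np.array(texts).reshape((-1, 1))


batch = build_model_inputs("dog", ["a dog catching a frisbee", "a cat on a sofa"])
print(batch.shape)   # (2, 1)
print(batch[0, 0])   # dog ||| a dog catching a frisbee
```

The resulting `batch` could then be passed as `feed_dict={x: batch}` in place of `trX[:4].reshape((-1, 1))`.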